"""
Candidates
----------
"""
import sys
from math import ceil
from math import floor
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
import numpy as np
from anndata import AnnData # type: ignore
import metacells.parameters as pr
import metacells.utilities as ut
if "sphinx" not in sys.argv[0]:
import metacells.extensions as xt # type: ignore
__all__ = [
"compute_candidate_metacells",
"choose_seeds",
"optimize_partitions",
"score_partitions",
]
@ut.logged()
def _seeds_count_for(
*,
total_umis: float,
total_size: float,
min_metacell_size: float,
max_metacell_size: float,
min_metacell_umis: float,
max_metacell_umis: float,
) -> int:
min_seeds_count_by_umis = int(ceil(total_umis / max_metacell_umis))
ut.log_calc("min_seeds_count_by_umis", min_seeds_count_by_umis)
max_seeds_count_by_umis = int(ceil(total_umis / min_metacell_umis))
ut.log_calc("max_seeds_count_by_umis", max_seeds_count_by_umis)
min_seeds_count_by_size = int(ceil(total_size / max_metacell_size))
ut.log_calc("min_seeds_count_by_size", min_seeds_count_by_size)
max_seeds_count_by_size = int(ceil(total_size / min_metacell_size))
ut.log_calc("max_seeds_count_by_size", max_seeds_count_by_size)
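    # The UMIs and size constraints each bracket the allowed number of seeds. If
    # the two brackets are disjoint, compromise on the midpoint between their
    # nearest ends; otherwise, take the midpoint of their intersection.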
if max_seeds_count_by_size < min_seeds_count_by_umis:
seeds_count = int(round((max_seeds_count_by_size + min_seeds_count_by_umis) / 2))
elif max_seeds_count_by_umis < min_seeds_count_by_size:
seeds_count = int(round((max_seeds_count_by_umis + min_seeds_count_by_size) / 2))
else:
min_seeds_count = max(min_seeds_count_by_size, min_seeds_count_by_umis)
max_seeds_count = min(max_seeds_count_by_size, max_seeds_count_by_umis)
seeds_count = int(round((min_seeds_count + max_seeds_count) / 2))
ut.log_return("seeds_count", seeds_count)
return seeds_count
def _reduce_communities(
    *,
    outgoing_edge_weights: ut.CompressedMatrix,
incoming_edge_weights: ut.CompressedMatrix,
community_of_nodes: ut.NumpyVector,
hot_communities: List[int],
target_metacell_umis: float,
node_umis: ut.NumpyVector,
min_metacell_umis: float,
max_metacell_umis: float,
target_metacell_size: float,
min_metacell_size: float,
max_metacell_size: float,
max_split_min_cut_strength: float,
min_cut_seed_cells: int,
must_complete_cover: bool,
min_seed_size_quantile: float,
max_seed_size_quantile: float,
random_seed: int,
cooldown_pass: float,
cooldown_node: float,
cooldown_phase: float,
increase_phase: float,
cold_temperature: float,
atomic_candidates: Set[Tuple[int, ...]],
) -> Tuple[float, float]:
np.random.seed(random_seed)
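    # Alternate between optimizing the partition and cutting/splitting the
    # resulting communities, until a pass changes nothing. Each round cools the
    # temperature and increases the maximal metacell size to help convergence.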
while True:
ut.log_calc("cold_temperature", cold_temperature)
score = _optimize_partitions(
outgoing_edge_weights=outgoing_edge_weights,
incoming_edge_weights=incoming_edge_weights,
community_of_nodes=community_of_nodes,
hot_communities=hot_communities,
node_umis=node_umis,
low_partition_umis=min_metacell_umis,
target_partition_umis=target_metacell_umis,
high_partition_umis=max_metacell_umis,
low_partition_size=min_metacell_size,
target_partition_size=target_metacell_size,
high_partition_size=max_metacell_size,
random_seed=random_seed,
cooldown_pass=cooldown_pass,
cooldown_node=cooldown_node,
cold_temperature=cold_temperature,
)
cold_temperature = cold_temperature * (1 - cooldown_phase)
max_metacell_size *= increase_phase
ut.log_calc("communities", community_of_nodes, formatter=ut.groups_description)
ut.log_calc("max_metacell_size", max_metacell_size)
split_communities_count, cut_communities_count, hot_communities = _cut_split_communities(
outgoing_edge_weights=outgoing_edge_weights,
community_of_nodes=community_of_nodes,
target_metacell_umis=target_metacell_umis,
node_umis=node_umis,
min_metacell_umis=min_metacell_umis,
max_metacell_umis=max_metacell_umis,
target_metacell_size=target_metacell_size,
min_metacell_size=min_metacell_size,
max_metacell_size=max_metacell_size,
max_split_min_cut_strength=max_split_min_cut_strength,
min_cut_seed_cells=min_cut_seed_cells,
must_complete_cover=must_complete_cover,
atomic_candidates=atomic_candidates,
)
if split_communities_count + cut_communities_count == 0:
return cold_temperature, score
max_seeds_count = np.max(community_of_nodes) + 1
if cut_communities_count > 0:
hot_communities.append(max_seeds_count)
max_seeds_count += 1
_choose_seeds(
outgoing_edge_weights=outgoing_edge_weights,
incoming_edge_weights=incoming_edge_weights,
seed_of_cells=community_of_nodes,
max_seeds_count=max_seeds_count,
min_seed_size_quantile=min_seed_size_quantile,
max_seed_size_quantile=max_seed_size_quantile,
random_seed=random_seed,
)
def _cut_split_communities( # pylint: disable=too-many-branches,too-many-statements
*,
outgoing_edge_weights: ut.CompressedMatrix,
community_of_nodes: ut.NumpyVector,
target_metacell_umis: float,
node_umis: ut.NumpyVector,
min_metacell_umis: float,
max_metacell_umis: float,
target_metacell_size: float,
min_metacell_size: float,
max_metacell_size: float,
max_split_min_cut_strength: float,
min_cut_seed_cells: int,
must_complete_cover: bool,
atomic_candidates: Set[Tuple[int, ...]],
) -> Tuple[int, int, List[int]]:
communities_count = np.max(community_of_nodes) + 1
assert communities_count > 0
split_communities_count = 0
cut_communities_count = 0
next_new_community_index = communities_count
hot_communities = []
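    # First pass: split communities that are too large (in UMIs or in cells)
    # into a few random parts. All of the community's nodes are released (-1)
    # except for one randomly-chosen seed node per part; a later seeding pass
    # grows the parts around these seeds.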
for community_index in range(communities_count):
community_mask = community_of_nodes == community_index
community_size = np.sum(community_mask)
community_umis = np.sum(node_umis[community_mask])
if (community_umis > max_metacell_umis and community_size >= 3 * min_metacell_size) or (
community_size > max_metacell_size and community_umis >= 3 * min_metacell_umis
):
community_indices = np.where(community_mask)[0]
split_parts_count = max(
2,
min(
int(round(community_umis / target_metacell_umis)), int(round(community_size / target_metacell_size))
),
)
ut.logger().debug(
"community: %s size: %s umis: %s is too large, split into %s",
community_index,
community_size,
community_umis,
split_parts_count,
)
all_parts_exist = False
while not all_parts_exist:
split_parts_assignment = np.random.randint( # type: ignore
split_parts_count, size=len(community_indices)
)
all_parts_exist = True
for split_index in range(split_parts_count):
if not np.any(split_parts_assignment == split_index):
all_parts_exist = False
break
community_of_nodes[community_indices] = -1
for split_index in range(split_parts_count):
split_part_mask = split_parts_assignment == split_index
assert np.any(split_part_mask)
split_part_indices = community_indices[split_part_mask]
if split_index == 0:
hot_communities.append(community_index)
community_of_nodes[split_part_indices[np.random.randint(len(split_part_indices))]] = community_index
else:
hot_communities.append(next_new_community_index)
community_of_nodes[
split_part_indices[np.random.randint(len(split_part_indices))]
] = next_new_community_index
next_new_community_index += 1
split_communities_count += 1
continue
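    # Second pass: only if nothing was split, try to min-cut each community to
    # detect ones that are not internally well-connected. Communities that the
    # cut leaves unchanged are recorded as atomic so they are not re-examined.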
if split_communities_count == 0:
for community_index in range(communities_count):
community_mask = community_of_nodes == community_index
community_indices_key = tuple(np.where(community_mask)[0])
if community_indices_key in atomic_candidates:
continue
action = _min_cut_community(
outgoing_edge_weights=outgoing_edge_weights,
community_of_nodes=community_of_nodes,
cut_community_index=community_index,
max_split_min_cut_strength=max_split_min_cut_strength,
min_cut_seed_cells=min_cut_seed_cells,
must_complete_cover=must_complete_cover,
new_community_index=next_new_community_index,
)
if action == "unchanged":
atomic_candidates.add(community_indices_key)
continue
hot_communities.append(community_index)
            ut.logger().debug(
                "community: %s nodes: %s was %s",
                community_index,
                np.sum(community_mask),
                action,
            )
if action == "split":
hot_communities.append(next_new_community_index)
split_communities_count += 1
next_new_community_index += 1
else:
cut_communities_count += 1
if split_communities_count + cut_communities_count > 0:
ut.logger().debug("old communities: %s", communities_count)
ut.logger().debug("cut communities: %s", cut_communities_count)
ut.logger().debug("split communities: %s", split_communities_count)
ut.logger().debug("hot communities: %s", len(hot_communities))
ut.logger().debug("total communities: %s", communities_count + split_communities_count)
else:
ut.logger().debug("no communities were cut or split")
return (split_communities_count, cut_communities_count, hot_communities)
def _min_cut_community(
    *,
    outgoing_edge_weights: ut.CompressedMatrix,
community_of_nodes: ut.NumpyVector,
cut_community_index: int,
max_split_min_cut_strength: float,
min_cut_seed_cells: int,
must_complete_cover: bool,
new_community_index: int,
) -> str:
community_mask = community_of_nodes == cut_community_index
if np.sum(community_mask) < 2:
return "unchanged"
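    # Extract the community's subgraph and symmetrize its weights, since the
    # minimal cut is computed on an undirected graph.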
community_edge_weights = outgoing_edge_weights[community_mask, :][:, community_mask]
community_edge_weights += community_edge_weights.T
cut, cut_strength = ut.min_cut(community_edge_weights)
if cut_strength is None:
return "unchanged"
if cut_strength > max_split_min_cut_strength:
return "unchanged"
ut.logger().debug(
"min cut community: %s partitions: %s + %s = %s strength: %s",
cut_community_index,
len(cut.partition[0]),
len(cut.partition[1]),
community_edge_weights.shape[0],
cut_strength,
)
community_indices = np.where(community_mask)[0]
if len(cut.partition[0]) < len(cut.partition[1]):
small_partition = 0
else:
small_partition = 1
if len(cut.partition[small_partition]) >= min_cut_seed_cells:
second_partition_indices = community_indices[cut.partition[1]]
community_of_nodes[second_partition_indices] = new_community_index
return "split"
if must_complete_cover and cut_strength > 0:
ut.logger().debug("give up on small cut: %s", len(cut.partition[small_partition]))
return "unchanged"
small_community_indices = community_indices[cut.partition[small_partition]]
community_of_nodes[small_community_indices] = -1
return "cut"
def _find_small_communities(
*,
community_of_nodes: ut.NumpyVector,
node_umis: ut.NumpyVector,
min_metacell_umis: float,
max_metacell_umis: float,
min_metacell_size: float,
max_metacell_size: float,
lowest_metacell_size: float,
) -> Tuple[Set[int], int, float]:
communities_count = np.max(community_of_nodes) + 1
assert communities_count > 0
small_communities: Set[int] = set()
small_nodes_count = 0
small_nodes_umis = 0
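    # A community is "small" if it is tiny (below the lowest allowed size), if it
    # has fewer cells than the minimum while holding less than half the maximal
    # UMIs, or if it has fewer UMIs than the minimum while holding less than half
    # the maximal size.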
for community_index in range(communities_count):
community_mask = community_of_nodes == community_index
community_size = np.sum(community_mask)
community_umis = np.sum(node_umis[community_mask])
if community_size < lowest_metacell_size:
ut.logger().debug("community: %s size: %s is tiny", community_index, community_size)
elif community_size < min_metacell_size and 2 * community_umis < max_metacell_umis:
ut.logger().debug(
"community: %s umis: %s and size: %s is too few", community_index, community_umis, community_size
)
elif community_umis < min_metacell_umis and 2 * community_size < max_metacell_size:
ut.logger().debug(
"community: %s umis: %s and size: %s is too small", community_index, community_umis, community_size
)
else:
continue
small_communities.add(community_index)
small_nodes_count += community_size
small_nodes_umis += community_umis
return (small_communities, small_nodes_count, small_nodes_umis)
def _cancel_communities(community_of_nodes: ut.NumpyVector, cancelled_communities: Set[int]) -> int:
communities_count = np.max(community_of_nodes) + 1
kept_communities_count = 0
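    # Release the nodes of the cancelled communities (setting them to -1) and
    # renumber the surviving communities so their indices stay consecutive.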
for community_index in range(communities_count):
if community_index in cancelled_communities:
community_of_nodes[community_of_nodes == community_index] = -1
continue
if community_index > kept_communities_count:
community_of_nodes[community_of_nodes == community_index] = kept_communities_count
kept_communities_count += 1
assert kept_communities_count == communities_count - len(cancelled_communities)
return kept_communities_count
@ut.logged()
@ut.timed_call()
def choose_seeds(
*,
edge_weights: ut.CompressedMatrix,
seed_of_cells: Optional[ut.NumpyVector] = None,
max_seeds_count: int,
min_seed_size_quantile: float = pr.min_seed_size_quantile,
max_seed_size_quantile: float = pr.max_seed_size_quantile,
random_seed: int,
) -> ut.NumpyVector:
"""
Choose initial assignment of cells to seeds based on the ``edge_weights``.
Returns a vector assigning each node (cell) to a seed (initial community).
If ``seed_of_cells`` is specified, it is expected to contain a vector of partial seeds. Only
cells which have a negative seed will be assigned a new seed. New seeds will be created so that
the total number of seeds will not exceed ``max_seeds_count``. The ``seed_of_cells`` will be
modified in-place and returned.
Otherwise, a new vector is created, initialized with ``-1`` (that is, no seed) for all nodes,
filled as above, and returned.
**Computation Parameters**
1. We compute for each candidate node the number of nodes it is connected to (by an outgoing
edge).
    2. We pick as a seed a random node whose number of connected nodes ("seed size") quantile is at
       least ``min_seed_size_quantile`` and at most ``max_seed_size_quantile``. This ensures we pick
       seeds that are neither too small nor too large, so that a small number of seeds still gives a
       good coverage of the population.
3. We assign each of the connected nodes to their seed, and discount them from the number of
connected nodes of the remaining unassigned nodes.
4. We repeat this until we reach the target number of seeds.
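
    A minimal usage sketch (the tiny two-clique ``edge_weights`` matrix below is an illustrative
    assumption, not taken from real data):

    .. code-block:: python

        import numpy as np
        import scipy.sparse as sp

        # Two well-connected groups of nodes: {0, 1, 2} and {3, 4, 5}.
        weights = np.zeros((6, 6), dtype="float32")
        weights[:3, :3] = 0.5
        weights[3:, 3:] = 0.5
        np.fill_diagonal(weights, 0.0)

        seed_of_cells = choose_seeds(
            edge_weights=sp.csr_matrix(weights),
            max_seeds_count=2,
            random_seed=123456,
        )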
"""
size = edge_weights.shape[0]
outgoing_edge_weights = ut.mustbe_compressed_matrix(edge_weights)
assert ut.is_layout(outgoing_edge_weights, "row_major")
incoming_edge_weights = ut.mustbe_compressed_matrix(ut.to_layout(outgoing_edge_weights, layout="column_major"))
assert ut.is_layout(incoming_edge_weights, "column_major")
if seed_of_cells is None:
seed_of_cells = np.full(size, -1, dtype="int32")
else:
assert seed_of_cells.dtype == "int32"
assert outgoing_edge_weights.shape == incoming_edge_weights.shape == (len(seed_of_cells), len(seed_of_cells))
_choose_seeds(
outgoing_edge_weights=outgoing_edge_weights,
incoming_edge_weights=incoming_edge_weights,
seed_of_cells=seed_of_cells,
max_seeds_count=max_seeds_count,
min_seed_size_quantile=min_seed_size_quantile,
max_seed_size_quantile=max_seed_size_quantile,
random_seed=random_seed,
)
return seed_of_cells
@ut.logged()
@ut.timed_call()
def _choose_seeds(
*,
outgoing_edge_weights: ut.CompressedMatrix,
incoming_edge_weights: ut.CompressedMatrix,
seed_of_cells: ut.NumpyVector,
max_seeds_count: int,
min_seed_size_quantile: float,
max_seed_size_quantile: float,
random_seed: int,
) -> int:
ut.log_calc("partial seeds", seed_of_cells, formatter=ut.groups_description)
seeds_count = xt.choose_seeds(
outgoing_edge_weights.data,
outgoing_edge_weights.indices,
outgoing_edge_weights.indptr,
incoming_edge_weights.data,
incoming_edge_weights.indices,
incoming_edge_weights.indptr,
random_seed,
max_seeds_count,
min_seed_size_quantile,
max_seed_size_quantile,
seed_of_cells,
)
ut.log_calc("chosen seeds count", seeds_count)
ut.log_calc("chosen seeds", seed_of_cells, formatter=ut.groups_description)
assert np.min(seed_of_cells) == 0
return seeds_count
@ut.logged()
@ut.timed_call()
def optimize_partitions(
*,
edge_weights: ut.CompressedMatrix,
community_of_nodes: ut.NumpyVector,
node_umis: ut.NumpyVector,
low_partition_umis: int,
target_partition_umis: int,
high_partition_umis: int,
low_partition_size: int,
target_partition_size: int,
high_partition_size: int,
cooldown_pass: float = pr.cooldown_pass,
cooldown_node: float = pr.cooldown_node,
random_seed: int,
) -> float:
"""
Optimize partition to candidate metacells (communities) using the ``edge_weights``.
Returns the score of the optimized partition.
This modifies the ``community_of_nodes`` in-place.
    The goal is to maximize the "stability" goal function, which is defined to be the ratio between
    (1) the probability that, selecting a random node and either a random outgoing edge or a random
    incoming edge (biased by their weights), the node connected to by that edge is in the same
    community (metacell) and (2) the probability that a random edge would lead to this same
    community (the fraction of its number of nodes out of the total).
To maximize this, we repeatedly pass on a randomized permutation of the nodes, and for each
node, move it to a random "better" community. When deciding if a community is better, we
consider both (1) just the "local" product of the sum of the weights of incoming and outgoing edges
between the node and the current and candidate communities and (2) the effect on the "global" goal
function (considering the impact on this product for all other nodes connected to the current
node).
We define a notion of ``temperature`` (initially, 1 - ``cooldown_pass``, default:
{cooldown_pass}) and we give a weight of ``temperature`` to the local score and
(1 - ``temperature``) to the global score. When we move to the next node, we multiply the
temperature by 1 - ``cooldown_pass``. If we did not move the node, we multiply its temperature
    by ``cooldown_node`` (default: {cooldown_node}). To accelerate the algorithm, we skip looking at
    nodes which are colder than the global temperature. If we don't move any node, we reduce the
global temperature below that of any cold node; if there are no such nodes, we reduce it to zero
to perform a final hill-climbing phase.
    This simulated-annealing-like behavior helps the algorithm escape local maxima, although of
    course no claim is made of achieving the global maximum of the goal function.
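
    A minimal usage sketch (reusing the hypothetical two-clique ``weights`` matrix and the
    ``seed_of_cells`` from the :py:func:`choose_seeds` example above; all the UMIs and size limits
    here are illustrative assumptions):

    .. code-block:: python

        node_umis = np.full(6, 1000.0, dtype="float32")

        score = optimize_partitions(
            edge_weights=sp.csr_matrix(weights),
            community_of_nodes=seed_of_cells,
            node_umis=node_umis,
            low_partition_umis=1_000,
            target_partition_umis=3_000,
            high_partition_umis=9_000,
            low_partition_size=1,
            target_partition_size=3,
            high_partition_size=9,
            random_seed=123456,
        )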
"""
outgoing_edge_weights = ut.mustbe_compressed_matrix(edge_weights)
assert ut.is_layout(outgoing_edge_weights, "row_major")
assert 0 < low_partition_size < target_partition_size < high_partition_size
assert 0 < low_partition_umis < target_partition_umis < high_partition_umis
incoming_edge_weights = ut.mustbe_compressed_matrix(ut.to_layout(outgoing_edge_weights, layout="column_major"))
communities_count = np.max(community_of_nodes) + 1
assert communities_count > 0
assert ut.is_layout(incoming_edge_weights, "column_major")
return _optimize_partitions(
outgoing_edge_weights=outgoing_edge_weights,
incoming_edge_weights=incoming_edge_weights,
random_seed=random_seed,
node_umis=node_umis,
target_partition_umis=target_partition_umis,
low_partition_umis=low_partition_umis,
high_partition_umis=high_partition_umis,
target_partition_size=target_partition_size,
low_partition_size=low_partition_size,
high_partition_size=high_partition_size,
cooldown_pass=cooldown_pass,
cooldown_node=cooldown_node,
community_of_nodes=community_of_nodes,
hot_communities=list(range(communities_count)),
cold_temperature=cooldown_pass,
)
@ut.logged()
@ut.timed_call()
def _optimize_partitions(
*,
outgoing_edge_weights: ut.CompressedMatrix,
incoming_edge_weights: ut.CompressedMatrix,
community_of_nodes: ut.NumpyVector,
hot_communities: List[int],
node_umis: ut.NumpyVector,
low_partition_umis: float,
target_partition_umis: float,
high_partition_umis: float,
target_partition_size: float,
low_partition_size: float,
high_partition_size: float,
cooldown_pass: float,
cooldown_node: float,
cold_temperature: float,
random_seed: int,
) -> float:
assert community_of_nodes.dtype == "int32"
assert node_umis.dtype == "float32"
communities_count = np.max(community_of_nodes) + 1
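    # Byte mask of the "hot" communities whose nodes should be reconsidered by
    # the extension code (padded to at least two entries).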
hot_communities_mask = np.zeros(max(communities_count, 2), dtype="int8")
hot_communities_mask[hot_communities] = 1
score = xt.optimize_partitions(
outgoing_edge_weights.data,
outgoing_edge_weights.indices,
outgoing_edge_weights.indptr,
incoming_edge_weights.data,
incoming_edge_weights.indices,
incoming_edge_weights.indptr,
random_seed,
node_umis,
low_partition_umis,
target_partition_umis,
high_partition_umis,
low_partition_size,
target_partition_size,
high_partition_size,
cooldown_pass,
cooldown_node,
community_of_nodes,
hot_communities_mask,
cold_temperature,
)
ut.log_calc("score", score)
ut.log_calc("partitions", community_of_nodes, formatter=ut.groups_description)
assert np.min(community_of_nodes) == 0
return score
@ut.logged()
@ut.timed_call()
def score_partitions(
*,
node_umis: ut.NumpyVector,
low_partition_umis: float,
target_partition_umis: float,
high_partition_umis: float,
low_partition_size: int,
target_partition_size: int,
high_partition_size: int,
edge_weights: ut.CompressedMatrix,
partition_of_nodes: ut.NumpyVector,
temperature: float,
with_orphans: bool = True,
) -> float:
"""
    Compute the "stability" goal function, which is defined to be the ratio between (1) the
    probability that, selecting a random node and either a random outgoing edge or a random
    incoming edge (biased by their weights), the node connected to by that edge is in the same
    community (metacell) and (2) the probability that a random edge would lead to this same
    community (the fraction of its number of nodes out of the total). Returns the computed score.
If ``with_orphans`` is True (the default), outlier nodes are included in the computation. In
general we add 1e-6 to the product of the incoming and outgoing weights so we can safely log it
for efficient computation; thus orphans are given a very small (non-zero) weight so the overall
score is not zeroed even when including them.
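
    A minimal usage sketch (reusing the hypothetical ``weights``, ``node_umis`` and partition from
    the :py:func:`optimize_partitions` example above; the ``temperature`` value is an illustrative
    assumption):

    .. code-block:: python

        score = score_partitions(
            node_umis=node_umis,
            low_partition_umis=1_000,
            target_partition_umis=3_000,
            high_partition_umis=9_000,
            low_partition_size=1,
            target_partition_size=3,
            high_partition_size=9,
            edge_weights=sp.csr_matrix(weights),
            partition_of_nodes=seed_of_cells,
            temperature=0.01,
        )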
"""
assert partition_of_nodes.dtype == "int32"
assert node_umis.dtype == "float32"
outgoing_edge_weights = ut.mustbe_compressed_matrix(edge_weights)
assert ut.is_layout(outgoing_edge_weights, "row_major")
incoming_edge_weights = ut.mustbe_compressed_matrix(ut.to_layout(outgoing_edge_weights, layout="column_major"))
assert ut.is_layout(incoming_edge_weights, "column_major")
hot_communities_mask = np.zeros(np.max(partition_of_nodes) + 1, dtype="int8")
with ut.unfrozen(partition_of_nodes):
with ut.timed_step(".score"):
score = xt.score_partitions(
outgoing_edge_weights.data,
outgoing_edge_weights.indices,
outgoing_edge_weights.indptr,
incoming_edge_weights.data,
incoming_edge_weights.indices,
incoming_edge_weights.indptr,
node_umis,
low_partition_umis,
target_partition_umis,
high_partition_umis,
low_partition_size,
target_partition_size,
high_partition_size,
temperature,
partition_of_nodes,
hot_communities_mask,
with_orphans,
)
ut.log_calc("score", score)
return score