Source code for minterpy.core.tree

"""
This module contains the `MultiIndexTree` class.

The `MultiIndexTree` class encapsulates all the components
(organized in a tree-like structure) necessary to perform a multidimensional
(multivariate) divided difference scheme (DDS).
This scheme is used to transform polynomial coefficients
from the Lagrange basis to the Newton basis.

For more details, see :doc:`/fundamentals/dds`.

----

"""
from __future__ import annotations

from typing import TYPE_CHECKING

from minterpy.dds import (
    compile_problem_sizes,
    compile_splits,
    compile_subtree_sizes,
    precompute_masks,
)
from minterpy.global_settings import ARRAY, ARRAY_DICT  # noqa

if TYPE_CHECKING:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from .grid import Grid
    from .multi_index import MultiIndexSet


__all__ = ["MultiIndexTree"]


[docs] class MultiIndexTree: """Base class for MultiIndexTree """ # TODO prevent dynamic attribute assignment (-> safe memory) # __slots__ = ["multi_index", "split_positions", "subtree_sizes", "stored_masks", "generating_points"]
[docs] def __init__(self, grid: Grid): multi_index = grid.multi_index if not multi_index.is_downward_closed: raise ValueError( "trying to use the divided difference scheme " "(multi-index tree) with non-downward-closed multi-indices, " "but DDS only works for downward-closed multi-indices " "(without 'holes')." ) self.grid = grid exponents = multi_index.exponents nr_exponents, spatial_dimension = exponents.shape # NOTE: the tree structure ("splitting") depends on the exponents # in each dimension of the sorted multi index array # pre-compute and store where the splits appear in the exponent array # this implicitly defines the "nodes" of the tree # TODO compute on demand? NOTE: tree is being constructed only on demand (DDS) # TODO reverse the dim order of all # (NOTE: then the "dim_idx" will then be counter intuitive: 0 for highest dimension...) self.split_positions = compile_splits(exponents) # also store the size of all nodes = how many exponent entries belong to this split # in combination with the positions of all appearing splits # the sizes fully determine the structure of the multi index tree # (position and amount of children etc.) self.subtree_sizes = compile_subtree_sizes(nr_exponents, self.split_positions) self.problem_sizes = compile_problem_sizes(self.subtree_sizes) # TODO improvement: also "pre-compute" more of the recursion through the tree, # avoid computing the node indices each time self._stored_masks: ARRAY_DICT | None = None
@property def multi_index(self) -> MultiIndexSet: """Returns the multi index set of the grid used to construct the tree. :return: the multi index set of the grid """ return self.grid.multi_index @property def stored_masks(self) -> ARRAY_DICT: """Returns the stored masks of the tree. :return: correspondencies between the left and right nodes of the tree """ # the intermediary results required for DDS # TODO remove when regular DDS functionality is no longer required (together with the dds module) if self._stored_masks is None: # lazy evaluation # based on the splittings one can compute all required correspondences # between nodes in the left and the right of the tree # this mapping can then be used to compute the nD DDS efficiently exponents = self.multi_index.exponents self._stored_masks = precompute_masks( self.split_positions, self.subtree_sizes, exponents ) return self._stored_masks