Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ def bounds(self, _min=None, _max=None):

return (_min, _max)

@property
def start(self):
"""The start value."""
if self.direction is Forward:
return self.dim.symbolic_min
else:
return self.dim.symbolic_max

@property
def step(self):
"""The step value."""
Expand Down
7 changes: 2 additions & 5 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,6 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -886,7 +885,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate', 'groupby')
assert mode in (None, 'immediate')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -903,9 +902,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class IndexMode(Tag):
REGULAR = IndexMode('regular')
IRREGULAR = IndexMode('irregular')

# Symbols to create mock data depdendencies
# Symbols to create mock data dependencies
mocksym0 = Symbol(name='__⋈_0__')
mocksym1 = Symbol(name='__⋈_1__')

Expand Down
36 changes: 27 additions & 9 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.ir.support import Forward, Scope
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted)
frozendict, is_integer, filter_sorted, EnrichedTuple)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand All @@ -28,7 +28,22 @@ class HaloLabel(Tag):
STENCIL = HaloLabel('stencil')


HaloSchemeEntry = namedtuple('HaloSchemeEntry', 'loc_indices loc_dirs halos dims')
class HaloSchemeEntry(EnrichedTuple):

__rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')

def __init__(self, loc_indices, loc_dirs, halos, dims, getters=None):
self.loc_indices = frozendict(loc_indices)
self.loc_dirs = frozendict(loc_dirs)
self.halos = frozenset(halos)
self.dims = frozenset(dims)

def __hash__(self):
return hash((self.loc_indices,
self.loc_dirs,
self.halos,
self.dims))


Halo = namedtuple('Halo', 'dim side')

Expand Down Expand Up @@ -121,7 +136,10 @@ def union(self, halo_schemes):
Create a new HaloScheme from the union of a set of HaloSchemes.
"""
halo_schemes = [hs for hs in halo_schemes if hs is not None]
if not halo_schemes:

if len(halo_schemes) == 1:
return halo_schemes[0]
elif not halo_schemes:
return None

fmapper = {}
Expand Down Expand Up @@ -365,6 +383,10 @@ def distributed_aindices(self):
def loc_indices(self):
return set().union(*[i.loc_indices.keys() for i in self.fmapper.values()])

@cached_property
def loc_values(self):
return set().union(*[i.loc_indices.values() for i in self.fmapper.values()])

@cached_property
def arguments(self):
return self.dimensions | set(flatten(self.honored.values()))
Expand Down Expand Up @@ -503,8 +525,6 @@ def classify(exprs, ispace):

loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
ispace.directions)
halos = frozenset(halos)
dims = frozenset(dims)

mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)

Expand Down Expand Up @@ -556,7 +576,7 @@ def process_loc_indices(raw_loc_indices, directions):
known = set().union(*[i._defines for i in loc_indices])
loc_dirs = {d: v for d, v in directions.items() if d in known}

return frozendict(loc_indices), frozendict(loc_dirs)
return loc_indices, loc_dirs


class HaloTouch(sympy.Function, Reconstructable):
Expand Down Expand Up @@ -634,9 +654,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
# Nope, let's try with the next Indexed, if any
continue

hse = HaloSchemeEntry(frozendict(loc_indices),
frozendict(loc_dirs),
hse0.halos, hse0.dims)
hse = hse0._rebuild(loc_indices=loc_indices, loc_dirs=loc_dirs)

else:
continue
Expand Down
Loading
Loading