Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational

from devito.logger import warning
from devito.tools import Tag, as_tuple
from devito.types.dimension import StencilDimension

Expand Down Expand Up @@ -260,6 +261,18 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
-------
An IndexSet, representing an ordered list of indices.
"""
# Check size of input weights
if nweights > 0:
do, dw = order + 1 + order % 2, nweights
if do < dw:
raise ValueError(f"More weights ({nweights}) provided than the maximum"
f"stencil size ({order + 1}) for order {order} scheme")
elif do > dw:
warning(f"Less weights ({nweights}) provided than the stencil size"
f"({order + 1}) for order {order} scheme."
" Reducing order to {nweights//2}")
order = nweights - nweights % 2

# Evaluation point
x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim]))

Expand All @@ -276,23 +289,15 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
side = side or centered

# Indices range
o_min = int(np.ceil(mid - order/2)) + side.val
o_max = int(np.floor(mid + order/2)) + side.val
r = (nweights or order) / 2
o_min = int(np.ceil(mid - r)) + side.val
o_max = int(np.floor(mid + r)) + side.val
if o_max == o_min:
if dim.is_Time or not expr.is_Staggered:
o_max += 1
else:
o_min -= 1

if nweights > 0 and (o_max - o_min + 1) != nweights:
# We cannot infer how the stencil should be centered
# if nweights is more than one extra point.
assert nweights == (o_max - o_min + 1) + 1
# In the "one extra" case we need to pad with one point to symmetrize
if (o_max - mid) > (mid - o_min):
o_min -= 1
else:
o_max += 1
# StencilDimension and expression
d = make_stencil_dimension(expr, o_min, o_max)
iexpr = expr.indices_ref[dim] + d * dim.spacing
Expand Down
31 changes: 30 additions & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from conftest import assert_structure, get_params, get_arrays, check_array
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
cos, sin)
Coefficient, Substitutions, cos, sin)
from devito.finite_differences import Weights
from devito.arch.compiler import OneapiCompiler
from devito.ir import Expression, FindNodes, FindSymbols
Expand Down Expand Up @@ -76,6 +76,35 @@ def test_multiple_cross_derivs(self, coeffs, expected):
weights = {f for f in functions if isinstance(f, Weights)}
assert len(weights) == expected

@pytest.mark.parametrize('order', [1, 2])
@pytest.mark.parametrize('nweight', [None, +4, -4])
def test_legacy_api(self, order, nweight):
grid = Grid(shape=(51, 51, 51))
x, y, z = grid.dimensions

nweight = 0 if nweight is None else nweight
so = 8

u = TimeFunction(name='u', grid=grid, space_order=so,
coefficients='symbolic')

w0 = np.arange(so + 1 + nweight) + 1
wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}'
wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))'

coeffs_x_p1 = Coefficient(order, u, x, w0)

coeffs = Substitutions(coeffs_x_p1)

eqn = Eq(u, u.dx.dy + u.dx2 + .37, coefficients=coeffs)

if nweight > 0:
with pytest.raises(ValueError):
op = Operator(eqn, opt=('advanced', {'expand': False}))
else:
op = Operator(eqn, opt=('advanced', {'expand': False}))
assert f'{wdef} = {wstr}' in str(op)


class Test1Pass:

Expand Down
Loading