diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index d153d082bd..5856a767ef 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -4,6 +4,7 @@ import numpy as np from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational +from devito.logger import warning from devito.tools import Tag, as_tuple from devito.types.dimension import StencilDimension @@ -260,6 +261,18 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights ------- An IndexSet, representing an ordered list of indices. """ + # Check size of input weights + if nweights > 0: + do, dw = order + 1 + order % 2, nweights + if do < dw: + raise ValueError(f"More weights ({nweights}) provided than the maximum" + f"stencil size ({order + 1}) for order {order} scheme") + elif do > dw: + warning(f"Less weights ({nweights}) provided than the stencil size" + f"({order + 1}) for order {order} scheme." + " Reducing order to {nweights//2}") + order = nweights - nweights % 2 + # Evaluation point x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim])) @@ -276,23 +289,15 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights side = side or centered # Indices range - o_min = int(np.ceil(mid - order/2)) + side.val - o_max = int(np.floor(mid + order/2)) + side.val + r = (nweights or order) / 2 + o_min = int(np.ceil(mid - r)) + side.val + o_max = int(np.floor(mid + r)) + side.val if o_max == o_min: if dim.is_Time or not expr.is_Staggered: o_max += 1 else: o_min -= 1 - if nweights > 0 and (o_max - o_min + 1) != nweights: - # We cannot infer how the stencil should be centered - # if nweights is more than one extra point. - assert nweights == (o_max - o_min + 1) + 1 - # In the "one extra" case we need to pad with one point to symmetrize - if (o_max - mid) > (mid - o_min): - o_min -= 1 - else: - o_max += 1 # StencilDimension and expression d = make_stencil_dimension(expr, o_min, o_max) iexpr = expr.indices_ref[dim] + d * dim.spacing diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index c82b548c29..86c61f6561 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -3,7 +3,7 @@ from conftest import assert_structure, get_params, get_arrays, check_array from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator, - cos, sin) + Coefficient, Substitutions, cos, sin) from devito.finite_differences import Weights from devito.arch.compiler import OneapiCompiler from devito.ir import Expression, FindNodes, FindSymbols @@ -76,6 +76,35 @@ def test_multiple_cross_derivs(self, coeffs, expected): weights = {f for f in functions if isinstance(f, Weights)} assert len(weights) == expected + @pytest.mark.parametrize('order', [1, 2]) + @pytest.mark.parametrize('nweight', [None, +4, -4]) + def test_legacy_api(self, order, nweight): + grid = Grid(shape=(51, 51, 51)) + x, y, z = grid.dimensions + + nweight = 0 if nweight is None else nweight + so = 8 + + u = TimeFunction(name='u', grid=grid, space_order=so, + coefficients='symbolic') + + w0 = np.arange(so + 1 + nweight) + 1 + wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}' + wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))' + + coeffs_x_p1 = Coefficient(order, u, x, w0) + + coeffs = Substitutions(coeffs_x_p1) + + eqn = Eq(u, u.dx.dy + u.dx2 + .37, coefficients=coeffs) + + if nweight > 0: + with pytest.raises(ValueError): + op = Operator(eqn, opt=('advanced', {'expand': False})) + else: + op = Operator(eqn, opt=('advanced', {'expand': False})) + assert f'{wdef} = {wstr}' in str(op) + class Test1Pass: