39 changes: 39 additions & 0 deletions docs/source/pitfall.rst
@@ -70,6 +70,45 @@ in the returned expression.
Replacing ``auto tmp`` with ``xt::xarray<double> tmp`` does not change anything, ``tmp``
is still an lvalue and thus captured by reference.

.. warning::

    This issue is particularly subtle with reducer functions like :cpp:func:`xt::amax`,
    :cpp:func:`xt::sum`, etc. Consider the following function:

    .. code::

        template <typename T>
        xt::xtensor<T, 2> logSoftmax(const xt::xtensor<T, 2> &matrix)
        {
            xt::xtensor<T, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
            auto shifted = matrix - maxVals;
            auto expVals = xt::exp(shifted);
            auto sumExp = xt::sum(expVals, {1}, xt::keep_dims);
            return shifted - xt::log(sumExp);
        }

    This function may produce incorrect results or crash, especially in optimized builds.
    The issue is that ``shifted``, ``expVals``, and ``sumExp`` are all lazy expressions
    that hold references to local variables. When the function returns, these local
    variables are destroyed, and the returned expression contains dangling references.

    The fix is to evaluate reducer results and the returned expression explicitly.
    Element-wise lazy expressions (like ``shifted`` and ``expVals``) are safe to
    leave as ``auto``, but reducer results (like ``sumExp``) must be materialized
    before being used in a subsequent element-wise expression:

    .. code::

        template <typename T>
        xt::xtensor<T, 2> logSoftmax(const xt::xtensor<T, 2> &matrix)
        {
            xt::xtensor<T, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
            auto shifted = matrix - maxVals;
            auto expVals = xt::exp(shifted);
            xt::xtensor<T, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
            return xt::xtensor<T, 2>(shifted - xt::log(sumExp));
        }
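
The reference-capture behaviour described in this warning can be illustrated without xtensor. The following is a minimal sketch in plain C++; ``LazyDiff``, ``evaluate``, and ``shift_by_first`` are hypothetical names invented for this illustration and are not part of xtensor, whose expression types are considerably more involved. It only shows the general idea: a lazy node that stores references is valid as long as its operands are alive, so it should be materialized into an owning container before those operands go out of scope.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical lazy node (NOT an xtensor type): it stores references to its
// operands and computes each element only when it is requested.
struct LazyDiff
{
    const std::vector<double>& lhs;
    const std::vector<double>& rhs;

    std::size_t size() const { return lhs.size(); }
    double operator[](std::size_t i) const { return lhs[i] - rhs[i]; }
};

// Materialize a lazy node into a container that owns its data.
std::vector<double> evaluate(const LazyDiff& expr)
{
    std::vector<double> out(expr.size());
    for (std::size_t i = 0; i < expr.size(); ++i)
    {
        out[i] = expr[i];
    }
    return out;
}

// Safe: the lazy node is evaluated while the local `offset` is still alive,
// and the caller receives an owning container.
std::vector<double> shift_by_first(const std::vector<double>& values)
{
    assert(!values.empty());
    std::vector<double> offset(values.size(), values.front());  // local operand
    LazyDiff shifted{values, offset};  // holds a reference to `offset`
    return evaluate(shifted);          // materialized before `offset` is destroyed
}

// Unsafe variant, shown for contrast only (left commented out on purpose):
//
// LazyDiff shift_by_first_lazy(const std::vector<double>& values)
// {
//     std::vector<double> offset(values.size(), values.front());
//     return LazyDiff{values, offset};  // `offset` is destroyed here; the
// }                                     // returned node refers to a dead object
```

Calling ``shift_by_first`` on ``{3.0, 5.0, 4.0}`` yields an owning vector whose elements are each input minus the first input. The commented-out variant compiles but returns a node whose ``rhs`` reference dangles, which is undefined behaviour as soon as an element is read.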

Random numbers not consistent
-----------------------------

47 changes: 47 additions & 0 deletions test/test_xmath.cpp
@@ -969,4 +969,51 @@ namespace xt
EXPECT_TRUE(xt::allclose(expected, unwrapped));
}
}

// Test for GitHub issue #2871: proper handling of intermediate results.
// This test documents the correct way to use reducers with keep_dims
// when intermediate expressions are needed.
TEST(xmath, issue_2871_intermediate_result_handling)
{
    // Lazy expressions hold references to their operands, so returning one
    // from a function can leave dangling references. Reducer results and the
    // returned expression must therefore be evaluated explicitly.

    // The CORRECT way: reducer results are evaluated into containers;
    // element-wise lazy expressions are safe to leave as auto.
    auto logSoftmax_correct = [](const xt::xtensor<double, 2>& matrix)
    {
        xt::xtensor<double, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
        auto shifted = matrix - maxVals;
        auto expVals = xt::exp(shifted);
        xt::xtensor<double, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
        return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
    };

    // Alternative CORRECT way: use xt::eval for reducer results.
    auto logSoftmax_eval = [](const xt::xtensor<double, 2>& matrix)
    {
        auto maxVals = xt::eval(xt::amax(matrix, {1}, xt::keep_dims));
        auto shifted = matrix - maxVals;
        auto expVals = xt::exp(shifted);
        auto sumExp = xt::eval(xt::sum(expVals, {1}, xt::keep_dims));
        return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
    };

    // Test data
    xt::xtensor<double, 2> input = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};

    // Both implementations should produce the same result.
    auto result1 = logSoftmax_correct(input);
    auto result2 = logSoftmax_eval(input);

    EXPECT_TRUE(xt::allclose(result1, result2));

    // Verify the result is a valid log-softmax: exp(result) must sum to 1
    // along each row (axis 1).
    auto exp_result = xt::exp(result1);
    auto row_sums = xt::sum(exp_result, {1});
    xt::xtensor<double, 1> expected_sums = {1.0, 1.0};
    EXPECT_TRUE(xt::allclose(row_sums, expected_sums));
}
}
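
For cross-checking the values this test expects, a dependency-free reference can be useful. The sketch below is an assumption-laden illustration, not part of the change: it uses plain `std::vector` instead of xtensor, and `Matrix`, `log_softmax_rows`, and `exp_row_sum` are hypothetical names. It applies the same max-shift formulation as the lambdas above and exposes the same validity check, namely that the exponentials of each output row sum to 1.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper type for this sketch: a row-major matrix of doubles.
using Matrix = std::vector<std::vector<double>>;

// Row-wise log-softmax using the max-shift trick for numerical stability:
//   out[i][j] = (x[i][j] - max_i) - log(sum_k exp(x[i][k] - max_i))
Matrix log_softmax_rows(const Matrix& matrix)
{
    Matrix result;
    result.reserve(matrix.size());
    for (const auto& row : matrix)
    {
        assert(!row.empty());
        const double max_val = *std::max_element(row.begin(), row.end());

        // Eagerly computed row reduction (plays the role of `sumExp`).
        double sum_exp = 0.0;
        for (double v : row)
        {
            sum_exp += std::exp(v - max_val);
        }
        const double log_sum = std::log(sum_exp);

        std::vector<double> out;
        out.reserve(row.size());
        for (double v : row)
        {
            out.push_back(v - max_val - log_sum);
        }
        result.push_back(std::move(out));
    }
    return result;
}

// Sum of exp() over one row; close to 1 for a valid log-softmax row.
double exp_row_sum(const std::vector<double>& row)
{
    double total = 0.0;
    for (double v : row)
    {
        total += std::exp(v);
    }
    return total;
}
```

Because log-softmax is invariant to adding a constant to a row, the two rows of the test input `{{1, 2, 3}, {4, 5, 6}}` produce the same output row, which gives an additional sanity check beyond the row sums.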