Skip to content

Support using Expressions in numpy matrix operations #3636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyomo/core/base/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
import pyomo.core.expr.numeric_expr as numeric_expr
from pyomo.core.base.component import ComponentData, ModelComponentFactory
from pyomo.core.base.global_set import UnindexedComponent_index
from pyomo.core.base.indexed_component import IndexedComponent, UnindexedComponent_set
from pyomo.core.base.indexed_component import (
IndexedComponent,
UnindexedComponent_set,
IndexedComponent_NDArrayMixin,
)
from pyomo.core.expr.numvalue import as_numeric
from pyomo.core.base.initializer import Initializer

Expand Down Expand Up @@ -235,7 +239,7 @@ class _GeneralExpressionData(metaclass=RenamedClass):
@ModelComponentFactory.register(
"Named expressions that can be used in other expressions."
)
class Expression(IndexedComponent):
class Expression(IndexedComponent, IndexedComponent_NDArrayMixin):
"""A shared expression container, which may be defined over an index.

Parameters
Expand Down
19 changes: 17 additions & 2 deletions pyomo/core/tests/unit/test_expr_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pyomo.common.unittest as unittest

from pyomo.common.dependencies import numpy as np, numpy_available
from pyomo.environ import ConcreteModel, Var, Constraint
from pyomo.environ import ConcreteModel, Var, Constraint, Param, Expression


@unittest.skipUnless(numpy_available, "tests require numpy")
Expand All @@ -37,7 +37,7 @@ def test_scalar_operations(self):
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])

def test_vector_operations(self):
def test_variable_vector_operations(self):
m = ConcreteModel()
m.x = Var()
m.y = Var([0, 1, 2])
Expand Down Expand Up @@ -90,6 +90,21 @@ def test_vector_operations(self):
m.x * m.y * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
)

def test_expression_vector_operations(self):
m = ConcreteModel()
m.p = Param(range(3), range(2), initialize=lambda m, i, j: 10 * i + j)

m.e = Expression(range(3))
m.f = Expression(range(2))

expr = np.transpose(m.e) @ m.p @ m.f
print(expr)
self.assertExpressionsEqual(
expr,
(m.e[0] * 0 + m.e[1] * 10 + m.e[2] * 20) * m.f[0]
+ (m.e[0] * 1 + m.e[1] * 11 + m.e[2] * 21) * m.f[1],
)


if __name__ == "__main__":
unittest.main()
Loading