4.25. The Stiefel-Manifold Indicator
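This example shows how to encode orthonormal-column structure with an indicator UDF. The model is
\[\min_{X \in \mathbb{R}^{m \times n}} \; \tfrac{1}{2} \|X - Y\|_F^2 + \delta_{\mathrm{St}(m,n)}(X),\]
where
\[\mathrm{St}(m,n) = \{X \in \mathbb{R}^{m \times n} : X^\top X = I_n\}.\]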
This is the set of matrices with orthonormal columns, which is nonconvex. The solver acts as a practical local method; the result should be interpreted as a locally optimal solution or stationary point. When \(m = n\), the Stiefel manifold reduces to the orthogonal group \(\mathcal{O}_n = \{X \in \mathbb{R}^{n \times n} : X^\top X = I\}\), so the square orthogonal-matrix case is already covered as a special case of this example.
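The value returned by UDFBase.eval() is the indicator of the manifold:
\[f(X) = \begin{cases} 0, & X^\top X = I_n, \\ +\infty, & \text{otherwise}. \end{cases}\]
So eval checks whether the columns are orthonormal, up to a tolerance of 1e-9 on the Frobenius norm of \(X^\top X - I_n\):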
def eval(self, arglist):
    X = np.asarray(arglist[0], dtype=float)
    identity = np.eye(X.shape[1])
    return 0.0 if np.linalg.norm(X.T @ X - identity) <= 1e-9 else float("inf")
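The tolerance test can be tried outside the solver. The following is a minimal sketch that assumes only NumPy; the free function `stiefel_indicator` and its `tol` parameter are illustrative names that mirror the body of eval and are not part of the admm API.

```python
import numpy as np


def stiefel_indicator(X, tol=1e-9):
    """Return 0.0 if X has orthonormal columns (up to tol), else +inf."""
    X = np.asarray(X, dtype=float)
    # np.linalg.norm on a 2-D array defaults to the Frobenius norm.
    gram_error = np.linalg.norm(X.T @ X - np.eye(X.shape[1]))
    return 0.0 if gram_error <= tol else float("inf")


# Feasible point: the Q factor of a reduced QR factorization
# has orthonormal columns.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((5, 3)))
on_manifold = stiefel_indicator(Q)

# Infeasible point: scaling by 1.1 breaks the unit-norm condition.
off_manifold = stiefel_indicator(1.1 * Q)

print(on_manifold, off_manifold)  # 0.0 inf
```

A looser tolerance accepts more nearly orthonormal iterates as feasible; 1e-9 is simply the value used in this example.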
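The proximal operator of an indicator function is the Euclidean projection onto the underlying set, so it does not depend on the step parameter lamb, which is why the code below ignores it. For the Stiefel manifold, the projection returned by UDFBase.argmin() is given by the polar factor: if \(Z = U \Sigma V^\top\) is a thin SVD of \(Z\), then
\[\operatorname{prox}(Z) = U V^\top.\]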
That is exactly what the code computes:
def argmin(self, lamb, arglist):
    Z = np.asarray(arglist[0], dtype=float)
    u, _, vt = np.linalg.svd(Z, full_matrices=False)
    prox = u @ vt
    return [prox.tolist()]
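As a sanity check on the projection property, the sketch below compares the polar factor with randomly sampled matrices that also have orthonormal columns. It assumes only NumPy; `polar_factor` is a hypothetical standalone helper that repeats the SVD step from argmin, and the random sampling is an illustration rather than a proof.

```python
import numpy as np


def polar_factor(Z):
    """Return U @ V^T from the thin SVD Z = U diag(s) V^T."""
    Z = np.asarray(Z, dtype=float)
    u, _, vt = np.linalg.svd(Z, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(1)
Z = rng.standard_normal((6, 3))
P = polar_factor(Z)

# P has orthonormal columns: P^T P equals the identity up to round-off.
gram_error = np.linalg.norm(P.T @ P - np.eye(3))

# Frobenius distance from Z to its polar factor.
dist_to_P = np.linalg.norm(Z - P)

# Distances from Z to other orthonormal-column matrices (random QR factors).
other_dists = []
for _ in range(200):
    Q, _ = np.linalg.qr(rng.standard_normal((6, 3)))
    other_dists.append(np.linalg.norm(Z - Q))
min_other = min(other_dists)

print(gram_error, dist_to_P, min_other)
```

None of the sampled matrices should be closer to `Z` than `P` is, which is consistent with the polar factor being a Frobenius-norm projection onto the manifold.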
The UDFBase.arguments() method again just binds the UDF to one symbolic matrix:
def arguments(self):
    return [self.arg]
Complete runnable example:
import numpy as np
import admm


class StiefelIndicator(admm.UDFBase):
    def __init__(self, arg):
        self.arg = arg

    def arguments(self):
        return [self.arg]

    def eval(self, arglist):
        # 0 on the manifold (up to tolerance), +inf elsewhere
        X = np.asarray(arglist[0], dtype=float)
        identity = np.eye(X.shape[1])
        return 0.0 if np.linalg.norm(X.T @ X - identity) <= 1e-9 else float("inf")

    def argmin(self, lamb, arglist):
        # Projection onto the manifold: polar factor U V^T from the thin SVD
        Z = np.asarray(arglist[0], dtype=float)
        u, _, vt = np.linalg.svd(Z, full_matrices=False)
        prox = u @ vt
        return [prox.tolist()]


Y = np.array([[2.0, 0.0], [0.0, 0.5], [0.0, 0.0]])
model = admm.Model()
X = admm.Var("X", 3, 2)
model.setObjective(0.5 * admm.sum(admm.square(X - Y)) + StiefelIndicator(X))
model.optimize()

print(" * X: ", np.asarray(X.X))  # Expected: ≈ [[1, 0], [0, 1], [0, 0]]
print(" * model.ObjVal: ", model.ObjVal)  # Expected: ≈ 0.624999
This example is available as a standalone script in the examples/ folder of the ADMM repository:
python examples/udf_stiefel.py
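In this concrete example, the data matrix already points mostly along the first two coordinate directions, so the polar-factor projection returns the obvious orthonormal-column matrix
\[X^\star = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}.\]
It is easy to verify that this point lies on the Stiefel manifold:
\[(X^\star)^\top X^\star = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} = I_2.\]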
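The polar factor \(U V^\top\), built from the thin SVD, is a nearest matrix with orthonormal columns in the Frobenius norm. For this data the resulting objective value is \(\tfrac{1}{2}(1^2 + 0.5^2) = 0.625\), which matches the printed model.ObjVal up to solver tolerance.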
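Because the objective is a squared Frobenius distance plus the indicator, the expected result for this instance can be reproduced without the solver. The sketch below assumes only NumPy and reuses the matrix Y from the complete example; the names `X_star`, `gram`, and `obj` are illustrative.

```python
import numpy as np

# Data matrix from the complete example above.
Y = np.array([[2.0, 0.0], [0.0, 0.5], [0.0, 0.0]])

# Polar factor of Y: the projection of Y onto the Stiefel manifold.
u, _, vt = np.linalg.svd(Y, full_matrices=False)
X_star = u @ vt

# Feasibility check: the Gram matrix should be the 2x2 identity.
gram = X_star.T @ X_star

# Objective value 0.5 * ||X_star - Y||_F^2 at the projected point.
obj = 0.5 * np.sum((X_star - Y) ** 2)

print(np.round(X_star, 6))
print(np.round(gram, 6))
print(round(float(obj), 6))  # 0.625
```

The singular values of Y are distinct and nonzero, so the polar factor is unique here even though the individual singular vectors are only determined up to sign.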