Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ public double proposal() {
return -Math.log(scale);
} else {
// scale the beast.tree
final int internalNodes = tree.scale(scale);
return Math.log(scale) * (internalNodes - 2);
// tree.scale returns the log Jacobian (dof * log(scale));
// operator adds the -2*log(scale) kernel-symmetry correction.
final double treeLogJacobian = tree.scale(scale);
return treeLogJacobian - 2 * Math.log(scale);
}
}

Expand Down Expand Up @@ -205,9 +207,13 @@ public double proposal() {
// for the proof. It is supposed to be somewhere in an Alexei/Nicholes article.

// all Values assumed independent!
final int computedDoF = param.scale(scale);
final int usedDoF = (specifiedDoF > 0) ? specifiedDoF : computedDoF ;
hastingsRatio = (usedDoF - 2) * Math.log(scale);
// param.scale returns the log Jacobian (dof * log(scale));
// operator adds the -2*log(scale) kernel-symmetry correction.
final double paramLogJacobian = param.scale(scale);
final double dofLogScale = (specifiedDoF > 0)
? specifiedDoF * Math.log(scale)
: paramLogJacobian;
hastingsRatio = dofLogScale - 2 * Math.log(scale);
} else {
hastingsRatio = -Math.log(scale);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,9 @@ public int setValue(final int param, final double value) throws Exception {
}
return 1;
} else if (para instanceof Tree) {
double old = para.getArrayValue();
double scale = value / old;
((Tree) para).scale(scale);
// Use the Scalable contract: setScalableValue lands the tree's
// dilation-axis summary (sum of intervals) at exactly `value`.
((Tree) para).setScalableValue(value);
return ((Tree) para).getInternalNodeCount();
}
return 0;
Expand All @@ -765,7 +765,8 @@ public double getValue(final int param) {
if (f instanceof RealParameter) {
return f.getArrayValue(getX(param));
}
return ((Tree) f).getRoot().getHeight();
// Read the tree's position on its dilation axis (sum of intervals)
return ((Tree) f).getScalableValue();
}

// public double getLower(final int param) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ public double proposal() {
return Math.log(scale);
} else {
// scale the beast.tree
// tree.scale returns the log Jacobian (dof * log(scale));
// Bactrian kernel is symmetric so no kernel-ratio correction is needed.
final double scale = getScaler(0, Double.NaN);
final int scaledNodes = tree.scale(scale);
return Math.log(scale) * scaledNodes;
return tree.scale(scale);
}
}

Expand Down Expand Up @@ -125,10 +126,11 @@ public double proposal() {
// for the proof. It is supposed to be somewhere in an Alexei/Nicholes article.

// all Values assumed independent!
final double scale = getScaler(0, param.getValue(0));
final int computedDoF = param.scale(scale);
final int usedDoF = (specifiedDoF > 0) ? specifiedDoF : computedDoF ;
hastingsRatio = usedDoF * Math.log(scale);
// param.scale returns the log Jacobian (dof * log(scale));
// Bactrian kernel is symmetric so no kernel-ratio correction is needed.
final double scale = getScaler(0, param.getValue(0));
final double paramLogJacobian = param.scale(scale);
hastingsRatio = (specifiedDoF > 0) ? specifiedDoF * Math.log(scale) : paramLogJacobian;
} else {

// which position to scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ private void fullInit() {

final ClusterTree ctree = new ClusterTree();
ctree.initByName("initial", gtree, "clusterType", "upgma", "taxa", alignment);
gtree.scale(1 / mu);
// Affine helper: scales internal heights by 1/mu (preserving leaf heights).
// Tree.scale would do interval scaling, which is the wrong operation here.
gtree.scaleToRootHeight(gtree.getRoot().getHeight() / mu);

maxNsites = max(maxNsites, alignment.getSiteCount());
}
Expand Down Expand Up @@ -388,7 +390,8 @@ private void randomInit() {
s += 1.0/k;
}
final double rootHeight = (1/lam) * s;
stree.scale(rootHeight/stree.getRoot().getHeight());
// Affine helper: lands species tree root at rootHeight exactly.
stree.scaleToRootHeight(rootHeight);
randomInitGeneTrees(rootHeight);
// final List<Tree> geneTrees = genes.get();
// for (final Tree gtree : geneTrees) {
Expand Down
50 changes: 49 additions & 1 deletion beast-base/src/main/java/beast/base/evolution/tree/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,17 @@ public Set<String> getLengthMetaDataNames() {


/**
* scale height of this node and all its descendants
* Affine height scaling: multiply this node's height by {@code scale},
* recursively. Leaves and fake (sampled-ancestor) nodes are skipped.
* <p>
* Used by {@link Tree#scaleToRootHeight(double)} for the affine
* "land root at target" helper. NOT used by {@link Tree#scale(double)},
* which applies interval scaling via {@link #intervalScale(double)}.
* <p>
* Throws {@link IllegalArgumentException} for non-ultrametric trees if a
* scaled internal height drops below a leaf child's height. This is the
* historical {@code Node.scale} behaviour, preserved here for callers
* that explicitly want affine scaling and can handle the failure.
*
* @param scale scale factor
* @return degrees of freedom scaled (used for HR calculations)
Expand Down Expand Up @@ -832,6 +842,44 @@ public int scale(final double scale) {
return dof;
}

/**
* Interval scaling: multiply this node's "margin" (height above its taller
* child) by {@code scale}, recursively. Tip heights are preserved by
* construction, so the resulting tree is always valid for any positive
* scale factor.
* <p>
* Used by {@link Tree#scale(double)} as the contract-bound dilation
* operation. Each margin is independently multiplied by {@code scale}, so
* the tree's sum of margins (= {@link Tree#getScalableValue()}) is
* exactly multiplied by {@code scale}.
*
* @param scale positive scale factor
* @return number of intervals (margins) scaled, for HR calculations
*/
public int intervalScale(final double scale) {
if (isLeaf()) {
return 0;
}
// sampled-ancestor fake nodes: skip, recurse into the non-direct-ancestor child
if (isFake()) {
if (getLeft().isDirectAncestor()) {
return getRight().intervalScale(scale);
} else {
return getLeft().intervalScale(scale);
}
}
startEditing();
final double oldMargin = height - Math.max(getLeft().getHeight(), getRight().getHeight());
int scaledNodeCount = 1;
scaledNodeCount += getLeft().intervalScale(scale);
scaledNodeCount += getRight().intervalScale(scale);
// recompute minHeight after children have been scaled
final double minChildHeight = Math.max(getLeft().getHeight(), getRight().getHeight());
height = oldMargin * scale + minChildHeight;
isDirty |= Tree.IS_DIRTY;
return scaledNodeCount;
}

// /**
// * Used for sampled ancestor trees
// * Scales this node and all its descendants (either all descendants, or only non-sampled descendants)
Expand Down
85 changes: 70 additions & 15 deletions beast-base/src/main/java/beast/base/evolution/tree/Tree.java
Original file line number Diff line number Diff line change
Expand Up @@ -654,27 +654,82 @@ public void setEverythingDirty(final boolean isDirty) {
}
}

/**
* Scale this tree by factor {@code scale} along its dilation axis.
* <p>
* The dilation axis for {@code Tree} is the sum of intervals (margins above
* taller children) across internal non-fake nodes. Each margin is
* multiplied by {@code scale}; tip dates are preserved by construction.
* Equivalent to interval scaling: see {@link #getScalableValue()} for the
* summary that scales by exactly {@code scale} under this operation.
* <p>
* Always succeeds for any positive {@code scale} (no leaf can violate a
* branch-length constraint, because each margin remains positive).
*
* @return log Jacobian determinant of the move ({@code dof × log(scale)})
*/
@Override
public int scale(final double scale) {
return root.scale(scale);
public double scale(final double scale) {
int dof = root.intervalScale(scale);
return dof * Math.log(scale);
}

/**
* Read this tree's position on its dilation axis: the sum of intervals
* (margins above taller children) across internal non-fake nodes.
* <p>
* This summary is exactly {@code s}-equivariant under {@link #scale(double)}.
*/
@Override
public void scaleOne(int i, final double scale) {
startEditing(null);
double h = m_nodes[i].getHeight();
double newHeight = h * scale;
for (Node child : m_nodes[i].children) {
if (newHeight < child.getHeight()) {
throw new IllegalArgumentException("scale sets nodes below child result in negative branch length");
}
}
if (!m_nodes[i].isRoot() && newHeight > m_nodes[i].getParent().getHeight()) {
throw new IllegalArgumentException("scale sets nodes above parent result in negative branch length");
}
m_nodes[i].setHeight(newHeight);
public double getScalableValue() {
return computeSumIntervals(root);
}

/**
* Affine helper: scale the tree so its root height equals {@code targetRootHeight}.
* Multiplies every internal non-fake non-leaf height by
* {@code targetRootHeight / oldRootHeight}, keeping leaf heights fixed.
* <p>
* Not part of the Scalable contract. May throw
* {@link IllegalArgumentException} for non-ultrametric trees if the
* resulting state has a parent below a leaf child.
*
* @return number of internal nodes scaled (degrees of freedom)
*/
public int scaleToRootHeight(final double targetRootHeight) {
double oldRoot = root.getHeight();
if (oldRoot <= 0.0) {
throw new IllegalArgumentException(
"Cannot scale to root height: current root height is " + oldRoot);
}
return root.scale(targetRootHeight / oldRoot);
}

/**
* Compute the sum of margins (h_N - max(h_children)) across internal
* non-fake nodes. For sampled-ancestor trees, fake nodes are skipped.
*/
private double computeSumIntervals(final Node node) {
if (node.isLeaf()) {
return 0.0;
}
if (node.isFake()) {
// skip the fake; recurse into the non-direct-ancestor child
if (node.getLeft().isDirectAncestor()) {
return computeSumIntervals(node.getRight());
} else {
return computeSumIntervals(node.getLeft());
}
}
double margin = node.getHeight()
- Math.max(node.getLeft().getHeight(), node.getRight().getHeight());
double sum = margin;
sum += computeSumIntervals(node.getLeft());
if (node.getRight() != null) {
sum += computeSumIntervals(node.getRight());
}
return sum;
}

// /**
// * The same as scale but with option to scale all sampled nodes
Expand Down
117 changes: 94 additions & 23 deletions beast-base/src/main/java/beast/base/inference/Scalable.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,107 @@

import beast.base.core.Description;

@Description("For StateNodes that can be scaled by a scale/up-down operator")
/**
* A {@code Scalable} represents a state component that can be moved along a
* single positive-real dilation axis. The interface defines three operations
* that must be mutually consistent &mdash; together they form the
* <em>Scalable contract</em>:
*
* <ol>
* <li>{@link #scale(double)} dilates the component by a factor {@code s}
* and returns the log Jacobian determinant of that move
* (i.e. {@code log |det(∂new/∂old)|}).</li>
* <li>{@link #getScalableValue()} reads the component's current position
* on its dilation axis.</li>
* <li>{@link #setScalableValue(double)} moves the component so that
* {@code getScalableValue()} returns the supplied target {@code V}, and
* returns the log Jacobian for that move.</li>
* </ol>
*
* <p>The log Jacobian is the move's contribution to the Metropolis-Hastings
* acceptance ratio from the change-of-variables formula. The proposal density
* ratio (the "Hastings ratio" proper, {@code q(reverse)/q(forward)}) lives in
* the calling operator's kernel and is <em>not</em> part of the Scalable's
* return.</p>
*
* <p>The contract requires the following three invariants to hold for any
* valid {@code Scalable x} and any positive scale factor {@code s}:</p>
*
* <pre>
* // (1) scale-equivariance
* double v0 = x.getScalableValue();
* x.scale(s);
* assert x.getScalableValue() == s * v0;
*
* // (2) set is a fixed point of get
* x.setScalableValue(V);
* assert x.getScalableValue() == V;
*
* // (3) set composes with scale
* // x.setScalableValue(x.getScalableValue() * s)
* // produces the same state as
* // x.scale(s)
* </pre>
*
* <p>The choice of dilation axis (and therefore the meaning of
* {@code getScalableValue}) is bound to the implementation of {@code scale}.
* For example, an affine-scaling parameter exposes its value directly. A tree
* whose {@code scale} is interval-scaling exposes its sum-of-margins. A custom
* {@code Scalable} chooses whichever summary is exactly multiplied by {@code s}
* under its own {@code scale} operation.</p>
*
* <p>{@code scale(s)} is expected to succeed for any positive {@code s} that
* leaves the component in a valid state. Implementations may throw
* {@link IllegalArgumentException} for moves that produce an invalid state;
* such throws act as rejection signals for the calling operator. The contract
* invariants apply when {@code scale} does not throw.</p>
*
* @see <a href="https://github.com/CompEvol/beast3/issues/20">beast3 issue #20</a>
*/
@Description("State component that can be dilated along a 1-D axis by scale or up-down operators.")
public interface Scalable {

/**
* Scale StateNode with amount scale and
* Dilate this component by factor {@code s} along its scaling axis.
* After the call, {@link #getScalableValue()} returns
* {@code s * (its previous value)}.
*
* @param scale scaling factor
* @return the number of degrees of freedom used in this operation. This number varies
* for the different types of StateNodes. For example, for real
* valued n-dimensional parameters, it is n, for a tree it is the
* number of internal nodes being scaled.
* @throws IllegalArgumentException when StateNode become not valid, e.g. has
* values outside bounds or negative branch lengths.
* @param s positive scale factor
* @return log Jacobian determinant of this move
* @throws IllegalArgumentException if the move would produce an invalid state
*/
abstract public int scale(double scale);
double scale(double s);

/**
* only scale the i-th element of the StateNode
* @param i
* @param scale
* Read the component's current position on its dilation axis.
* The contract requires this to be exactly {@code s}-equivariant under
* {@link #scale(double)}.
*/
abstract public void scaleOne(int i, double scale);
double getScalableValue();

/**
* Move the component so that {@link #getScalableValue()} returns {@code V}.
* Defined as {@code scale(V / getScalableValue())}.
* <p>
* Implementations rarely need to override this; the default expresses the
* contract identity {@code setScalableValue(V) ≡ scale(V / getScalableValue())}
* directly. Override only if the dilation axis cannot be reached by a
* single multiplicative scale (rare).
*
* @param V target value (must be positive for typical multiplicative axes)
* @return log Jacobian determinant of this move
* @throws IllegalArgumentException if the current value is zero (no
* multiplicative scale can land at {@code V}) or if the resulting
* state would be invalid
*/
default double setScalableValue(double V) {
double current = getScalableValue();
if (current == 0.0) {
throw new IllegalArgumentException(
"Cannot set scalable value: current value is zero "
+ "(no multiplicative scale lands at " + V + ")");
}
return scale(V / current);
}

default double scaleAll(double scale) {
try {
int d = scale(scale);
return d * Math.log(scale);
} catch (IllegalArgumentException e) {
return Double.NEGATIVE_INFINITY;
}
}

}
Loading
Loading