Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package liquidjava.rj_language.opt;

import java.util.Optional;

import liquidjava.processor.SimplifiedVCImplication;
import liquidjava.processor.VCImplication;
import liquidjava.rj_language.Predicate;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.GroupExpression;

/**
* Simplifies VCImplication chains by propagating exact function invocation equalities
*/
public class VCFunctionSubstitution implements VCSimplificationPass {

/**
* A substitution discovered from a function invocation equality
*/
private record Substitution(VCImplication node, FunctionInvocation invocation, Expression replacement) {
}

/**
* Applies one function invocation substitution in a VC chain
*/
@Override
public VCImplication apply(VCImplication implication) {
VCImplication result = implication.clone();
Optional<Substitution> substitutionOpt = findSubstitution(result);

if (substitutionOpt.isPresent()) {
Substitution substitution = substitutionOpt.get();
result = substitute(result, substitution.node(), substitution.invocation(), substitution.replacement());
}
return result;
}

/**
* Preserves nodes before the source equality and starts rewriting at the source suffix
*/
private VCImplication substitute(VCImplication implication, VCImplication node, FunctionInvocation invocation,
Expression replacement) {
if (implication == null)
return null;

// skip the source node to remove it from the chain and start substitution from the next node
if (implication == node) {
VCImplication result = implication.copyWithRefinement(implication.getRefinement().clone());
result.setNext(substituteSuffix(implication.getNext(), node, invocation, replacement));
return result;
}

// preserve the current node and continue rewriting the suffix
VCImplication result = implication.copyWithRefinement(implication.getRefinement().clone());
result.setNext(substitute(implication.getNext(), node, invocation, replacement));
return result;
}

/**
* Rewrites every node after the source equality with one function substitution
*/
private VCImplication substituteSuffix(VCImplication implication, VCImplication source,
FunctionInvocation invocation, Expression replacement) {
if (implication == null)
return null;

VCImplication result = substituteNode(implication, source, invocation, replacement);
result.setNext(substituteSuffix(implication.getNext(), source, invocation, replacement));
return result;
}

/**
* Substitutes one exact function invocation inside one VC node while preserving simplification metadata
*/
private VCImplication substituteNode(VCImplication implication, VCImplication source, FunctionInvocation invocation,
Expression replacement) {
Expression expression = implication.getRefinement().getExpression().clone();
if (!containsExpression(expression, invocation))
return implication.copyWithRefinement(new Predicate(expression));

Expression substituted = expression.substitute(invocation, replacement.clone());
return new SimplifiedVCImplication(implication, new Predicate(substituted), source);
}

/**
* Finds the first function substitution candidate that is used in the remaining suffix
*/
private Optional<Substitution> findSubstitution(VCImplication implication) {
if (implication == null)
return Optional.empty();

Optional<Substitution> current = getSubstitution(implication);
if (current.isPresent() && containsExpression(implication.getNext(), current.get().invocation()))
return current;

return findSubstitution(implication.getNext());
}

/**
* Extracts a substitution from one VC node refinement
*/
private Optional<Substitution> getSubstitution(VCImplication implication) {
return getSubstitution(implication, implication.getRefinement().getExpression().clone());
}

/**
* Extracts a substitution from a top-level equality or conjunction
*/
private Optional<Substitution> getSubstitution(VCImplication implication, Expression expression) {
if (expression instanceof GroupExpression group)
return getSubstitution(implication, group.getExpression());

if (expression instanceof BinaryExpression binary && "&&".equals(binary.getOperator())) {
Optional<Substitution> left = getSubstitution(implication, binary.getFirstOperand());
if (left.isPresent())
return left;
return getSubstitution(implication, binary.getSecondOperand());
}

if (!(expression instanceof BinaryExpression binary) || !"==".equals(binary.getOperator()))
return Optional.empty();

Expression left = binary.getFirstOperand();
Expression right = binary.getSecondOperand();
if (left instanceof FunctionInvocation invocation && !containsExpression(right, left))
return Optional.of(new Substitution(implication, (FunctionInvocation) invocation.clone(), right.clone()));
if (right instanceof FunctionInvocation invocation && !containsExpression(left, right))
return Optional.of(new Substitution(implication, (FunctionInvocation) invocation.clone(), left.clone()));

return Optional.empty();
}

/**
* Checks whether an expression contains another expression
*/
private boolean containsExpression(Expression expression, Expression target) {
if (expression.equals(target))
return true;

for (Expression child : expression.getChildren())
if (containsExpression(child, target))
return true;
return false;
}

/**
* Checks whether a VC suffix contains an expression
*/
private boolean containsExpression(VCImplication implication, Expression target) {
for (VCImplication current = implication; current != null; current = current.getNext())
if (containsExpression(current.getRefinement().getExpression(), target))
return true;
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
*/
public class VCSimplification {

private static final List<VCSimplificationPass> PASSES = List.of(new VCSubstitution(), new VCBinderSimplification(),
new VCFolding(), new VCArithmeticSimplification(), new VCLogicalSimplification());
private static final List<VCSimplificationPass> PASSES = List.of(new VCSubstitution(), new VCFunctionSubstitution(),
new VCBinderSimplification(), new VCFolding(), new VCArithmeticSimplification(),
new VCLogicalSimplification());

/**
* Applies all available simplification steps to a VC chain until a fixed point is reached
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package liquidjava.rj_language.opt;

import static liquidjava.utils.VCTestUtils.*;

import liquidjava.processor.VCImplication;
import org.junit.jupiter.api.Test;

class VCFunctionSubstitutionTest {

private final VCFunctionSubstitution substitution = new VCFunctionSubstitution();

@Test
void substitutesExactFunctionInvocationIntoSuffix() {
VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0")));
}

@Test
void substitutesReverseFunctionEquality() {
VCImplication implication = vc("0 == f(x)", "f(y) == f(x) + 1");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("0 == f(x)"), expect("f(y) == 0 + 1", "0 == f(x)")));
}

@Test
void preservesSourceNode() {
VCImplication implication = vc("f(x) == 0", "f(x) > -1");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("f(x) == 0"), expect("0 > -1", "f(x) == 0")));
}

@Test
void doesNotRewriteEarlierNodesFromLaterEquality() {
VCImplication implication = vc("f(x) > 0", "f(x) == 1");

assertSimplificationSteps(substitution::apply, implication, chain(expect("f(x) > 0"), expect("f(x) == 1")));
}

@Test
void skipsUsedUpEqualityAndUsesNextAvailableEquality() {
VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1", "f(y) == 1");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("f(y) == 1")),
chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"),
expect("0 + 1 == 1", "f(y) == 0 + 1")));
}

@Test
void doesNotGeneralizeAcrossDifferentArguments() {
VCImplication implication = vc("f(x) == 0", "f(y) == 0");

assertSimplificationSteps(substitution::apply, implication, chain(expect("f(x) == 0"), expect("f(y) == 0")));
}

@Test
void ignoresRecursiveFunctionEquality() {
VCImplication implication = vc("f(x) == f(x) + 1", "f(x) > 0");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("f(x) == f(x) + 1"), expect("f(x) > 0")));
}

@Test
void extractsEqualityFromTopLevelConjunction() {
VCImplication implication = vc("ok && f(x) == 0", "f(y) == f(x) + 1");

assertSimplificationSteps(substitution::apply, implication,
chain(expect("ok && f(x) == 0"), expect("f(y) == 0 + 1", "ok && f(x) == 0")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class VCImplicationGenerator extends Generator<VCImplication> {

public static final String[] BINDERS = { "x", "y", "z", "w" };
public static final String[] FREE_VARS = { "a", "b", "c", "d" };
public static final String[] FUNCTIONS = { "f", "g" };
private static final String[] COMPARISON_OPS = { "==", "!=", ">=", ">", "<=", "<" };
private static final String[] BOOLEAN_OPS = { "&&", "||", "-->", "==", "!=" };
private static final String[] ARITHMETIC_OPS = { "+", "-", "*" };
Expand All @@ -21,7 +22,7 @@ public VCImplicationGenerator() {

@Override
public VCImplication generate(SourceOfRandomness random, GenerationStatus status) {
return switch (random.nextInt(0, 14)) {
return switch (random.nextInt(0, 18)) {
case 0 -> vc(substitution(random, "x"), comparison(random, "x"));
case 1 -> vc(reverseSubstitution(random, "x"), comparison(random, "x"));
case 2 -> vc(nonSubstitution(random, "x"), substitution(random, "y"), comparison(random, "y"));
Expand All @@ -36,6 +37,11 @@ public VCImplication generate(SourceOfRandomness random, GenerationStatus status
case 11 -> vc(logicalIdentity(random));
case 12 -> vc(unusedTrueBinder(random));
case 13 -> vc(falseBinder(random));
case 14 -> exactFunctionSubstitution(random);
case 15 -> reverseFunctionSubstitution(random);
case 16 -> chainedFunctionSubstitution(random);
case 17 -> differentArgumentFunctionSubstitution(random);
case 18 -> recursiveFunctionSubstitution(random);
default -> vc(substitution(random, "x"), substitution(random, "y"), foldableComparison(random));
};
}
Expand All @@ -62,6 +68,59 @@ private static String nonSubstitution(SourceOfRandomness random, String binder)
return "∀" + binder + ":int. " + binder + " == " + binder + " " + signed(random.nextInt(1, 5));
}

private static VCImplication exactFunctionSubstitution(SourceOfRandomness random) {
String function = functionName(random);
return vc(functionSubstitution(random, function, "a"), functionUse(random, function, "a"));
}

private static VCImplication reverseFunctionSubstitution(SourceOfRandomness random) {
String function = functionName(random);
return vc(reverseFunctionSubstitution(random, function, "a"), functionUse(random, function, "a"));
}

private static VCImplication chainedFunctionSubstitution(SourceOfRandomness random) {
String function = functionName(random);
return vc(functionSubstitution(random, function, "a"), dependentFunctionSubstitution(random, function),
functionUse(random, function, "b"));
}

private static VCImplication differentArgumentFunctionSubstitution(SourceOfRandomness random) {
String function = functionName(random);
return vc(functionSubstitution(random, function, "a"), functionUse(random, function, "b"));
}

private static VCImplication recursiveFunctionSubstitution(SourceOfRandomness random) {
String function = functionName(random);
String invocation = functionInvocation(function, "a");
return vc(invocation + " == " + invocation + " " + signed(random.nextInt(1, 5)),
functionUse(random, function, "a"));
}

private static String functionSubstitution(SourceOfRandomness random, String function, String argument) {
return functionInvocation(function, argument) + " == " + value(random);
}

private static String reverseFunctionSubstitution(SourceOfRandomness random, String function, String argument) {
return value(random) + " == " + functionInvocation(function, argument);
}

private static String dependentFunctionSubstitution(SourceOfRandomness random, String function) {
return functionInvocation(function, "b") + " == " + functionInvocation(function, "a") + " "
+ signed(random.nextInt(-3, 3));
}

private static String functionUse(SourceOfRandomness random, String function, String argument) {
return functionInvocation(function, argument) + " " + comparisonOperator(random) + " " + intLiteral(random);
}

private static String functionInvocation(String function, String argument) {
return function + "(" + argument + ")";
}

private static String functionName(SourceOfRandomness random) {
return FUNCTIONS[random.nextInt(0, FUNCTIONS.length - 1)];
}

private static String[] unusedTrueBinder(SourceOfRandomness random) {
return new String[] { "∀x:int. true", comparison(random, "a") };
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.pholser.junit.quickcheck.runner.JUnitQuickcheck;
import liquidjava.processor.VCImplication;
import liquidjava.processor.context.Context;
import liquidjava.processor.context.GhostFunction;
import liquidjava.rj_language.Predicate;
import liquidjava.rj_language.ast.BinaryExpression;
import liquidjava.rj_language.ast.Expression;
Expand All @@ -19,12 +20,15 @@
import liquidjava.smt.SMTResult;
import liquidjava.utils.TestUtils;
import org.junit.runner.RunWith;
import spoon.Launcher;
import spoon.reflect.factory.Factory;

@RunWith(JUnitQuickcheck.class)
public class VCSimplificationPropertyBasedTest {

private static final int TRIALS = 50; // number of random VCs to test
private static final int MAX_STEPS = 20; // to prevent infinite loops in case of non-termination
private static final Factory FACTORY = new Launcher().getFactory();

@Property(trials = TRIALS)
public void eachSimplificationStepPreservesVcSemantics(@From(VCImplicationGenerator.class) VCImplication vc) {
Expand All @@ -47,6 +51,9 @@ private static void setUpContext() {
TestUtils.addIntVariableToContext(variable);
for (String variable : VCImplicationGenerator.FREE_VARS)
TestUtils.addIntVariableToContext(variable);
for (String function : VCImplicationGenerator.FUNCTIONS)
Context.getInstance().addGhostFunction(
new GhostFunction(function, List.of("int"), FACTORY.Type().INTEGER_PRIMITIVE, FACTORY, ""));
}

private static void assertEquivalent(VCImplication unsimplified, VCImplication simplified, int step) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ void simplifyAppliesLongSubstitutionChainBeforeReachingFixedPoint() {
chain(expect("3 == 3", "2 + 1 == 3")), chain(expect("true", "3 == 3")));
}

@Test
void simplifyPropagatesFunctionInvocationEqualitiesBeforeReachingFixedPoint() {
VCImplication implication = vc("f(x) == 0", "f(y) == f(x) + 1", "f(y) == 1");

assertSimplificationSteps(VCSimplification::simplifyOnce, implication,
chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("f(y) == 1")),
chain(expect("f(x) == 0"), expect("f(y) == 0 + 1", "f(x) == 0"), expect("0 + 1 == 1", "f(y) == 0 + 1")),
chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("0 + 1 == 1", "f(y) == 0 + 1")),
chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("1 == 1", "0 + 1 == 1")),
chain(expect("f(x) == 0"), expect("f(y) == 1", "f(y) == 0 + 1"), expect("true", "1 == 1")));
}

@Test
void simplifyCombinesSubstitutionAndNestedFoldingAcrossFixedPoint() {
VCImplication implication = vc("∀x:int. x == 1", "∀y:int. y == x + 2", "y - 1 == 2");
Expand Down
Loading