From 91b51fe38b7fe069879d7507b557a967a5bd7e0a Mon Sep 17 00:00:00 2001 From: Malte Tammena Date: Tue, 14 May 2024 16:11:06 +0200 Subject: [PATCH] continue with scripts and debugging --- .gitignore | 3 + Cargo.lock | 204 ++++++++++++++++++++++++++++++- Cargo.toml | 6 + flake.nix | 40 +++--- scripts/__init__.py | 0 scripts/aba_generator_acyclic.py | 92 +++++++------- scripts/decode-result-folder.py | 62 ++++++++++ scripts/generate_nn.py | 69 +++++++++++ scripts/run-model.py | 33 +++++ src/aba/mod.rs | 7 ++ src/aba/prepared.rs | 9 +- src/aba/problems/mod.rs | 12 +- src/main.rs | 1 + 13 files changed, 467 insertions(+), 71 deletions(-) create mode 100644 scripts/__init__.py mode change 100644 => 100755 scripts/aba_generator_acyclic.py create mode 100755 scripts/decode-result-folder.py create mode 100755 scripts/generate_nn.py create mode 100755 scripts/run-model.py diff --git a/.gitignore b/.gitignore index 2328d65..9753bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ result* target/ .pre-commit-config.yaml +scripts/__pycache__ +acyclic +output-* diff --git a/Cargo.lock b/Cargo.lock index a421eee..08d5840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,11 +8,23 @@ version = "0.1.0" dependencies = [ "cadical", "clap", + "fun_time", "lazy_static", + "log", "nom", + "pretty_env_logger", "thiserror", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "anstream" version = "0.6.12" @@ -105,7 +117,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.11.0", "terminal_size", ] @@ -118,7 +130,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.49", ] [[package]] @@ -133,6 +145,54 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "errno" version = "0.3.8" @@ -143,12 +203,70 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "fun_time" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee194d43605ea83cca7af42af5f9001fab1a8e2220cb8a012e21dda6167fdb0" +dependencies = [ + "fun_time_derive", + "log", +] + +[[package]] +name = "fun_time_derive" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71555fd2db00938d82d29d8fa62a2ae80aed2c162c328d775f79e98d9212f013" +dependencies = [ + "darling", + "log", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "jobserver" version = "0.1.28" @@ -176,6 +294,12 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + [[package]] name = "memchr" version = "2.7.1" @@ -198,6 +322,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "pretty_env_logger" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c" +dependencies = [ + "env_logger", + "log", +] + [[package]] name = "proc-macro2" version = "1.0.78" @@ -216,6 +350,35 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + [[package]] name = "rustix" version = "0.38.31" @@ -229,12 +392,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.49" @@ -246,6 +426,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "terminal_size" version = "0.3.0" @@ -273,7 +462,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.49", ] [[package]] @@ -288,6 +477,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 0d72812..4c87556 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,12 @@ default-run = "aba2sat" [dependencies] cadical = "0.1.14" clap = { version = "4.4.8", features = ["wrap_help", "derive"] } +fun_time = { version = "0.3.4", optional = true, features = ["log"] } lazy_static = "1.4.0" +log = "0.4.21" nom = "7.1.3" +pretty_env_logger = "0.5.0" thiserror = "1.0.50" + +[features] +timing = ["dep:fun_time"] diff --git a/flake.nix b/flake.nix index 89fce12..4f8a5c0 100644 --- a/flake.nix +++ b/flake.nix @@ -185,27 +185,31 @@ drv = aba2sat; }; - devShells.default = craneLib.devShell { - # Inherit inputs from checks. - checks = self.checks.${system}; + devShells.default = let + python = pkgs.python3.withPackages (ps: [ps.torch ps.torchvision ps.psutil]); + in + craneLib.devShell { + # Inherit inputs from checks. + checks = self.checks.${system}; - RUST_LOG = "trace"; + RUST_LOG = "trace"; - inputsFrom = []; + inputsFrom = []; - packages = [ - pkgs.hyperfine - pkgs.lldb - pkgs.nil - pkgs.nodejs - pkgs.pre-commit - pkgs.pyright - pkgs.ruff-lsp - pkgs.shellcheck - pkgs.shfmt - self'.packages.aspforaba - ]; - }; + packages = [ + pkgs.hyperfine + pkgs.lldb + pkgs.nil + pkgs.nodejs + pkgs.pre-commit + pkgs.pyright + pkgs.ruff-lsp + pkgs.shellcheck + pkgs.shfmt + python + self'.packages.aspforaba + ]; + }; }; }; } diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/aba_generator_acyclic.py b/scripts/aba_generator_acyclic.py old mode 100644 new mode 100755 index cb1ce4d..045acb6 --- a/scripts/aba_generator_acyclic.py +++ b/scripts/aba_generator_acyclic.py @@ -1,9 +1,9 @@ +#!/usr/bin/env python3 + import random import argparse -import sys -def create_framework(n_sentences, n_assumptions, n_rules_per_head, - size_of_bodies, cycle_prob): +def create_framework(n_sentences, n_assumptions, n_rules_per_head, size_of_bodies, cycle_prob): """ Create a random framework. @@ -15,8 +15,8 @@ def create_framework(n_sentences, n_assumptions, n_rules_per_head, - max(size_of_bodies) <= n_sentences+n_assumptions """ - assumptions = ["a" + str(i) for i in range(n_assumptions)] - sentences = ["s" + str(i) for i in range(n_sentences-n_assumptions)] + assumptions = [str(i) for i in range(1,n_assumptions+1)] + sentences = [str(i) for i in range(n_assumptions+1,n_sentences+1)] contraries = {asmpt: random.choice(sentences+assumptions) for asmpt in assumptions} @@ -42,56 +42,54 @@ def create_framework(n_sentences, n_assumptions, n_rules_per_head, return assumptions, sentences, contraries, rules -def print_ASP(assumptions, contraries, rules, out_filename, query=None): +def print_ICCMA_format(assumptions, contraries, rules, n_sentences, out_filename): """ - Print the given framework in ASP format. + Print the given framework in the ICCMA 2023 format. """ + offset = len(assumptions) + with open(out_filename, 'w') as out: - for asm in assumptions: - out.write("assumption(" + asm + ").\n") + out.write(f"p aba {n_sentences}\n") + for i, asm in enumerate(assumptions): + out.write(f"a {asm}\n") + #print(f"a {asm}") for ctr in contraries: - out.write("contrary(" + ctr + "," + contraries.get(ctr) + ").\n") - for i, rule in enumerate(rules): - out.write("head(" + str(i) + "," + rule[0] + ").\n") - if rule[1]: - for body in rule[1]: - out.write("body(" + str(i) + "," + body + ").\n") - if query: - out.write("query(" + query + ").") - -n_sentences = int(sys.argv[1]) -cycle_prob = float(sys.argv[2]) -max_rules_per_head = 5 -max_body_size = 5 -n_a = int(round(0.15*n_sentences)) -n_rph = range(1,max_rules_per_head+1) -n_spb = range(1,max_body_size) - -framework = create_framework(n_sentences, n_a, n_rph, n_spb, cycle_prob) -print_ASP(framework[0], framework[2], framework[3], "generated_benchmark.asp", "s0") + out.write(f"c {ctr} {contraries.get(ctr)}\n") + #print(f"c {ctr} {contraries.get(ctr)}") + for rule in rules: + out.write(f"r {rule[0]} {' '.join(rule[1])}\n") + #print(f"r {rule[0]} {' '.join(rule[1])}") +def ICCMA23_benchmarks(sentences=[1000,2000,3000,4000,5000], max_rules_per_head_list=[5,10], max_rule_size_list=[5,10], assumption_ratios=[0.1,0.3], count=10, directory="iccma23_aba_benchmarks", identifier="aba"): + random.seed(811543731122527) + for sentence in sentences: + for assumption_ratio in assumption_ratios: + for max_rules_per_head in max_rules_per_head_list: + for max_rule_size in max_rule_size_list: + for i in range(count): + number_assumptions = int(round(assumption_ratio*sentence)) + number_rules_per_head = range(1,max_rules_per_head+1) + n_spb = range(1,max_rule_size+1) + filename = f"{directory}/{identifier}_{sentence}_{assumption_ratio}_{max_rules_per_head}_{max_rule_size}_{i}.aba" + print(filename) + framework = create_framework(sentence, number_assumptions, number_rules_per_head, n_spb, 0) + query = random.randint(1,number_assumptions) + with open(f"{filename}.asm", 'w') as out: + print(f"{filename}.asm") + out.write(f"{query}") + print_ICCMA_format(framework[0], framework[2], framework[3], sentence, filename) parser = argparse.ArgumentParser() parser.add_argument('-d', '--directory') parser.add_argument('-i', '--identifier') args = parser.parse_args() -directory = args.directory -identifier = args.identifier - -sens = [1000,2000,3000,4000,5000] -n_rules_max = [2,5,8,13] -rule_size_max = [2,5,8,13] -asmpt_ratio = [0.15,0.3,0.7] -for sen in sens: - for k in asmpt_ratio: - for rph_max in n_rules_max: - for spb_max in rule_size_max: - for i in range(10): - n_a = int(round(k*sen)) - n_rph = range(1,rph_max+1) - n_spb = range(1,spb_max+1) - filename = f"{directory}/{identifier}_{sen}_{k}_{rph_max}_{spb_max}_{i}.asp" - print(filename) - framework = create_framework(sen, n_a, n_rph, n_spb) - print_ASP(framework[0], framework[2], framework[3], filename) +ICCMA23_benchmarks( + sentences = [50,100,200,300,400,500,1000,2000], + max_rules_per_head_list = [1,2,4,8,16], + max_rule_size_list = [1,2,4,8,16], + assumption_ratios = [0.1,0.3,0.5,0.7,0.9], + count = 5, + directory=args.directory if args.directory is not None else "acyclic", + identifier=args.identifier if args.identifier is not None else "aba", +) diff --git a/scripts/decode-result-folder.py b/scripts/decode-result-folder.py new file mode 100755 index 0000000..2b07fc0 --- /dev/null +++ b/scripts/decode-result-folder.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +import glob +import json +import os +import argparse +import csv + +parser = argparse.ArgumentParser() +parser.add_argument("-d", "--directory") +parser.add_argument("-o", "--output") +args = parser.parse_args() + +# Path to the folder +folder_path = args.directory if args.directory is not None else "output" +output = args.output if args.output is not None else "all.csv" + + +def run(): + count = 0 + out = [] + # Using glob to match all .json files + for file_path in glob.glob(os.path.join(folder_path, "*.json")): + # Open and read the contents of the file + with open(file_path, "r", encoding="utf-8") as json_file: + ( + _ident, + atom_count, + assumption_ratio, + max_rules_per_head, + max_rule_size, + _idx, + ) = file_path.split("_") + data = json.load(json_file)["results"] + aba2sat, aspforaba = ( + (data[0], data[1]) + if ( + data[0]["command"] == "aba2sat" + and data[1]["command"] == "aspforaba" + ) + else (data[1], data[0]) + ) + speedup = float(aspforaba['mean']) / float(aba2sat['mean']) + out.append({ + "atom_count": atom_count, + "assumption_ratio": assumption_ratio, + "max_rules_per_head": max_rules_per_head, + "max_rule_size": max_rule_size, + "time": aba2sat["mean"], + "stddev": aba2sat['stddev'], + "speedup": speedup, + }) + if count > 700: + break + count += 1 + with open(output, 'w') as output_file: + output_file.write + writer = csv.DictWriter(output_file, fieldnames=out[0].keys()) + writer.writeheader() + writer.writerows(out) + +run() diff --git a/scripts/generate_nn.py b/scripts/generate_nn.py new file mode 100755 index 0000000..5a3edd1 --- /dev/null +++ b/scripts/generate_nn.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.optim as optim +import psutil +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + + +# Define the neural network architecture +class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.fc1 = nn.Linear(28*28, 16) # Input layer + self.fc2 = nn.Linear(16, 16) # Hidden layer 1 + self.fc3 = nn.Linear(16, 16) # Hidden layer 2 + self.fc4 = nn.Linear(16, 16) # Hidden layer 3 + self.fc5 = nn.Linear(16, 10) # Output layer + + def forward(self, x): + x = x.view(-1, 28*28) # Flatten the input images + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = torch.relu(self.fc3(x)) + x = torch.relu(self.fc4(x)) + x = self.fc5(x) + return x + +def run_training(): + # Get the number of physical cores + num_physical_cores = psutil.cpu_count(logical=False) + + torch.set_num_threads(num_physical_cores) + + # Load MNIST dataset + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) + train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) + train_loader = DataLoader(train_dataset, batch_size=2000, shuffle=True) + + # Initialize the neural network + model = NeuralNetwork() + + # Define loss function and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(model.parameters(), lr=0.01) + + # Training loop + epochs = 10 + for epoch in range(epochs): + running_loss = 0.0 + for i, data in enumerate(train_loader, 0): + inputs, labels = data + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() + if i % 10 == 9: # Print every 10 mini-batches + print('[%3d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10)) + running_loss = 0.0 + + torch.save(model.state_dict(), 'data/model.pth') + print('Finished Training') + +if __name__ == '__main__': + run_training() diff --git a/scripts/run-model.py b/scripts/run-model.py new file mode 100755 index 0000000..6d1bf73 --- /dev/null +++ b/scripts/run-model.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +import torch +from PIL import Image +import torchvision.transforms as transforms +from generate_nn import NeuralNetwork + +model = NeuralNetwork() +model.load_state_dict(torch.load('data/model.pth')) +model.eval() + +# Define a transform to convert the image to tensor and normalize it +# MNIST images are usually transformed to tensors and normalized with a mean of 0.1307 and std of 0.3081 +transform = transforms.Compose([ + transforms.Grayscale(num_output_channels=1), # Convert to grayscale if the image is in color + transforms.Resize((28, 28)), # Resize image to 28x28 pixels + transforms.ToTensor(), # Convert to PyTorch Tensor + transforms.Normalize((0.1307,), (0.3081,)), # Normalize pixel values +]) + +# Load the image +image = Image.open('test.png') + +# Apply the transform to the image +image = transform(image) + +# Add an extra batch dimension since PyTorch treats all inputs as batches +image = torch.unsqueeze(image, 0) + +with torch.no_grad(): + output = model(image) + predicted_class = output.argmax(dim = 1).item() + print('Predicted class:', predicted_class) diff --git a/src/aba/mod.rs b/src/aba/mod.rs index 2019519..510cf3c 100644 --- a/src/aba/mod.rs +++ b/src/aba/mod.rs @@ -101,6 +101,13 @@ impl Aba { } /// Prepare this aba for translation to SAT + #[cfg_attr( + feature = "timing", + fun_time::fun_time( + message = "Preparing ABA with max {max_loops:?} loops", + reporting = "log" + ) + )] pub fn prepare(self, max_loops: Option) -> PreparedAba { PreparedAba::new(self, max_loops) } diff --git a/src/aba/prepared.rs b/src/aba/prepared.rs index 0e012f5..398f320 100644 --- a/src/aba/prepared.rs +++ b/src/aba/prepared.rs @@ -27,7 +27,10 @@ impl PreparedAba { /// Create a new [`PreparedAba`] from a raw [`Aba`] pub fn new(mut aba: Aba, max_loops: Option) -> Self { trim_unreachable_rules(&mut aba); - let loops = calculate_loops_and_their_support(&aba, max_loops).collect(); + let loops = match max_loops { + Some(0) => vec![], + _ => calculate_loops_and_their_support(&aba, max_loops).collect(), + }; PreparedAba { aba, loops } } /// Translate the ABA into base rules / definitions for SAT solving @@ -138,6 +141,10 @@ impl PreparedAba { /// Iterates over all rules, marking reachable elements until /// no additional rule can be applied. Then removes every /// rule that contains any unreachable atom and returns the rest +#[cfg_attr( + feature = "timing", + fun_time::fun_time(message = "Triming unnecessary rules from ABA", reporting = "log") +)] fn trim_unreachable_rules(aba: &mut Aba) { // Begin with all assumptions marked as reachable let mut reachable: HashSet<_> = aba.assumptions().cloned().collect(); diff --git a/src/aba/problems/mod.rs b/src/aba/problems/mod.rs index 5c4c202..b05f0e9 100644 --- a/src/aba/problems/mod.rs +++ b/src/aba/problems/mod.rs @@ -70,7 +70,7 @@ pub fn solve(problem: P, aba: Aba, max_loops: Option) -> Resu map.as_raw_iter(&additional_clauses) .for_each(|raw| sat.add_clause(raw)); // A single solver call to determine the solution - if let Some(sat_result) = sat.solve() { + if let Some(sat_result) = call_sat_solver(&mut sat) { #[cfg(debug_assertions)] if sat_result { let rec = map.reconstruct(&sat).collect::>(); @@ -119,7 +119,7 @@ pub fn multishot_solve( map.as_raw_iter(&additional_clauses) .for_each(|raw| sat.add_clause(raw)); // Call the solver for the next result - let sat_result = sat.solve().ok_or(Error::SatCallInterrupted)?; + let sat_result = call_sat_solver(&mut sat).ok_or(Error::SatCallInterrupted)?; #[cfg(debug_assertions)] if sat_result { let rec = map.reconstruct(&sat).collect::>(); @@ -154,3 +154,11 @@ pub fn multishot_solve( iteration, )) } + +#[cfg_attr( + feature = "timing", + fun_time::fun_time(message = "Calling SAT solver", reporting = "log") +)] +fn call_sat_solver(sat: &mut Solver) -> Option { + sat.solve() +} diff --git a/src/main.rs b/src/main.rs index 8b5fbd5..dfd8e65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,6 +29,7 @@ trait IccmaFormattable { } fn __main() -> Result { + pretty_env_logger::init(); let args = match &*ARGS { Some(args) => args, None => {