continue with scripts and debugging

Malte Tammena 2024-05-14 16:11:06 +02:00
parent ef5f0606c0
commit 91b51fe38b
13 changed files with 467 additions and 71 deletions

.gitignore (vendored): 3 lines changed

@@ -2,3 +2,6 @@
 result*
 target/
 .pre-commit-config.yaml
+scripts/__pycache__
+acyclic
+output-*

Cargo.lock (generated): 204 lines changed

@@ -8,11 +8,23 @@ version = "0.1.0"
 dependencies = [
  "cadical",
  "clap",
+ "fun_time",
+ "lazy_static",
+ "log",
  "nom",
+ "pretty_env_logger",
  "thiserror",
 ]

+[[package]]
+name = "aho-corasick"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "anstream"
 version = "0.6.12"
@@ -105,7 +117,7 @@ dependencies = [
  "anstream",
  "anstyle",
  "clap_lex",
- "strsim",
+ "strsim 0.11.0",
  "terminal_size",
 ]
@@ -118,7 +130,7 @@ dependencies = [
  "heck",
  "proc-macro2",
  "quote",
- "syn",
+ "syn 2.0.49",
 ]

 [[package]]
@@ -133,6 +145,54 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"

+[[package]]
+name = "darling"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
+dependencies = [
+ "darling_core",
+ "darling_macro",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim 0.10.0",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
+dependencies = [
+ "darling_core",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "env_logger"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
+dependencies = [
+ "humantime",
+ "is-terminal",
+ "log",
+ "regex",
+ "termcolor",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.8"
@@ -143,12 +203,70 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

+[[package]]
+name = "fnv"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
+
+[[package]]
+name = "fun_time"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bee194d43605ea83cca7af42af5f9001fab1a8e2220cb8a012e21dda6167fdb0"
+dependencies = [
+ "fun_time_derive",
+ "log",
+]
+
+[[package]]
+name = "fun_time_derive"
+version = "0.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "71555fd2db00938d82d29d8fa62a2ae80aed2c162c328d775f79e98d9212f013"
+dependencies = [
+ "darling",
+ "log",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "heck"
 version = "0.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"

+[[package]]
+name = "hermit-abi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
+
+[[package]]
+name = "humantime"
+version = "2.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
+
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
+[[package]]
+name = "is-terminal"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
+dependencies = [
+ "hermit-abi",
+ "libc",
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "jobserver"
 version = "0.1.28"
@@ -176,6 +294,12 @@ version = "0.4.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"

+[[package]]
+name = "log"
+version = "0.4.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
+
 [[package]]
 name = "memchr"
 version = "2.7.1"
@@ -198,6 +322,16 @@ dependencies = [
  "minimal-lexical",
 ]

+[[package]]
+name = "pretty_env_logger"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c"
+dependencies = [
+ "env_logger",
+ "log",
+]
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.78"
@@ -216,6 +350,35 @@ dependencies = [
  "proc-macro2",
 ]

+[[package]]
+name = "regex"
+version = "1.10.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
+
 [[package]]
 name = "rustix"
 version = "0.38.31"
@@ -229,12 +392,29 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

 [[package]]
 name = "strsim"
 version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"

+[[package]]
+name = "strsim"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
+
+[[package]]
+name = "syn"
+version = "1.0.109"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
 [[package]]
 name = "syn"
 version = "2.0.49"
@@ -246,6 +426,15 @@ dependencies = [
  "unicode-ident",
 ]

+[[package]]
+name = "termcolor"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "terminal_size"
 version = "0.3.0"
@@ -273,7 +462,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn",
+ "syn 2.0.49",
 ]

 [[package]]
@@ -288,6 +477,15 @@ version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"

+[[package]]
+name = "winapi-util"
+version = "0.1.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
+dependencies = [
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "windows-sys"
 version = "0.48.0"

Cargo.toml

@@ -12,6 +12,12 @@ default-run = "aba2sat"

 [dependencies]
 cadical = "0.1.14"
 clap = { version = "4.4.8", features = ["wrap_help", "derive"] }
+fun_time = { version = "0.3.4", optional = true, features = ["log"] }
+lazy_static = "1.4.0"
+log = "0.4.21"
 nom = "7.1.3"
+pretty_env_logger = "0.5.0"
 thiserror = "1.0.50"

+[features]
+timing = ["dep:fun_time"]
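Because fun_time is declared optional, the dependency and all timing code only enter a build that explicitly asks for it, e.g. cargo build --features timing; a default build pulls in none of the timing machinery.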

flake.nix

@@ -185,27 +185,31 @@
       drv = aba2sat;
     };

-    devShells.default = craneLib.devShell {
-      # Inherit inputs from checks.
-      checks = self.checks.${system};
+    devShells.default = let
+      python = pkgs.python3.withPackages (ps: [ps.torch ps.torchvision ps.psutil]);
+    in
+      craneLib.devShell {
+        # Inherit inputs from checks.
+        checks = self.checks.${system};

-      RUST_LOG = "trace";
+        RUST_LOG = "trace";

-      inputsFrom = [];
+        inputsFrom = [];

-      packages = [
-        pkgs.hyperfine
-        pkgs.lldb
-        pkgs.nil
-        pkgs.nodejs
-        pkgs.pre-commit
-        pkgs.pyright
-        pkgs.ruff-lsp
-        pkgs.shellcheck
-        pkgs.shfmt
-        self'.packages.aspforaba
-      ];
-    };
+        packages = [
+          pkgs.hyperfine
+          pkgs.lldb
+          pkgs.nil
+          pkgs.nodejs
+          pkgs.pre-commit
+          pkgs.pyright
+          pkgs.ruff-lsp
+          pkgs.shellcheck
+          pkgs.shfmt
+          python
+          self'.packages.aspforaba
+        ];
+      };
   };
 };
}
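The pinned Python environment added here matches what the new scripts below import (torch, torchvision, psutil), so they run inside nix develop without any ad-hoc pip installs.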

scripts/__init__.py (new empty file): 0 lines changed

scripts/aba_generator_acyclic.py (mode changed: normal file → executable): 92 lines changed

@@ -1,9 +1,9 @@
 #!/usr/bin/env python3

 import random
 import argparse
-import sys

-def create_framework(n_sentences, n_assumptions, n_rules_per_head,
-                     size_of_bodies, cycle_prob):
+def create_framework(n_sentences, n_assumptions, n_rules_per_head, size_of_bodies, cycle_prob):
     """
     Create a random framework.
@@ -15,8 +15,8 @@ def create_framework(n_sentences, n_assumptions, n_rules_per_head,
     - max(size_of_bodies) <= n_sentences+n_assumptions
     """

-    assumptions = ["a" + str(i) for i in range(n_assumptions)]
-    sentences = ["s" + str(i) for i in range(n_sentences-n_assumptions)]
+    assumptions = [str(i) for i in range(1,n_assumptions+1)]
+    sentences = [str(i) for i in range(n_assumptions+1,n_sentences+1)]

     contraries = {asmpt: random.choice(sentences+assumptions) for asmpt in assumptions}
@@ -42,56 +42,54 @@
     return assumptions, sentences, contraries, rules

-def print_ASP(assumptions, contraries, rules, out_filename, query=None):
+def print_ICCMA_format(assumptions, contraries, rules, n_sentences, out_filename):
     """
-    Print the given framework in ASP format.
+    Print the given framework in the ICCMA 2023 format.
     """
+    offset = len(assumptions)
     with open(out_filename, 'w') as out:
-        for asm in assumptions:
-            out.write("assumption(" + asm + ").\n")
+        out.write(f"p aba {n_sentences}\n")
+        for i, asm in enumerate(assumptions):
+            out.write(f"a {asm}\n")
+            #print(f"a {asm}")
         for ctr in contraries:
-            out.write("contrary(" + ctr + "," + contraries.get(ctr) + ").\n")
-        for i, rule in enumerate(rules):
-            out.write("head(" + str(i) + "," + rule[0] + ").\n")
-            if rule[1]:
-                for body in rule[1]:
-                    out.write("body(" + str(i) + "," + body + ").\n")
-        if query:
-            out.write("query(" + query + ").")
-
-n_sentences = int(sys.argv[1])
-cycle_prob = float(sys.argv[2])
-max_rules_per_head = 5
-max_body_size = 5
-n_a = int(round(0.15*n_sentences))
-n_rph = range(1,max_rules_per_head+1)
-n_spb = range(1,max_body_size)
-framework = create_framework(n_sentences, n_a, n_rph, n_spb, cycle_prob)
-print_ASP(framework[0], framework[2], framework[3], "generated_benchmark.asp", "s0")
+            out.write(f"c {ctr} {contraries.get(ctr)}\n")
+            #print(f"c {ctr} {contraries.get(ctr)}")
+        for rule in rules:
+            out.write(f"r {rule[0]} {' '.join(rule[1])}\n")
+            #print(f"r {rule[0]} {' '.join(rule[1])}")
+
+def ICCMA23_benchmarks(sentences=[1000,2000,3000,4000,5000], max_rules_per_head_list=[5,10], max_rule_size_list=[5,10], assumption_ratios=[0.1,0.3], count=10, directory="iccma23_aba_benchmarks", identifier="aba"):
+    random.seed(811543731122527)
+    for sentence in sentences:
+        for assumption_ratio in assumption_ratios:
+            for max_rules_per_head in max_rules_per_head_list:
+                for max_rule_size in max_rule_size_list:
+                    for i in range(count):
+                        number_assumptions = int(round(assumption_ratio*sentence))
+                        number_rules_per_head = range(1,max_rules_per_head+1)
+                        n_spb = range(1,max_rule_size+1)
+                        filename = f"{directory}/{identifier}_{sentence}_{assumption_ratio}_{max_rules_per_head}_{max_rule_size}_{i}.aba"
+                        print(filename)
+                        framework = create_framework(sentence, number_assumptions, number_rules_per_head, n_spb, 0)
+                        query = random.randint(1,number_assumptions)
+                        with open(f"{filename}.asm", 'w') as out:
+                            print(f"{filename}.asm")
+                            out.write(f"{query}")
+                        print_ICCMA_format(framework[0], framework[2], framework[3], sentence, filename)

 parser = argparse.ArgumentParser()
 parser.add_argument('-d', '--directory')
 parser.add_argument('-i', '--identifier')
 args = parser.parse_args()
-directory = args.directory
-identifier = args.identifier

-sens = [1000,2000,3000,4000,5000]
-n_rules_max = [2,5,8,13]
-rule_size_max = [2,5,8,13]
-asmpt_ratio = [0.15,0.3,0.7]
-for sen in sens:
-    for k in asmpt_ratio:
-        for rph_max in n_rules_max:
-            for spb_max in rule_size_max:
-                for i in range(10):
-                    n_a = int(round(k*sen))
-                    n_rph = range(1,rph_max+1)
-                    n_spb = range(1,spb_max+1)
-                    filename = f"{directory}/{identifier}_{sen}_{k}_{rph_max}_{spb_max}_{i}.asp"
-                    print(filename)
-                    framework = create_framework(sen, n_a, n_rph, n_spb)
-                    print_ASP(framework[0], framework[2], framework[3], filename)
+ICCMA23_benchmarks(
+    sentences = [50,100,200,300,400,500,1000,2000],
+    max_rules_per_head_list = [1,2,4,8,16],
+    max_rule_size_list = [1,2,4,8,16],
+    assumption_ratios = [0.1,0.3,0.5,0.7,0.9],
+    count = 5,
+    directory=args.directory if args.directory is not None else "acyclic",
+    identifier=args.identifier if args.identifier is not None else "aba",
+)
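For reference, the files emitted by print_ICCMA_format above consist of a "p aba <n_sentences>" header followed by "a" (assumption), "c" (contrary), and "r" (rule) lines. A minimal sketch of a reader for that layout; the helper read_iccma is hypothetical and not part of this commit:

#!/usr/bin/env python3
def read_iccma(path):
    """Parse an ICCMA-2023-style .aba file (hypothetical helper, not in the commit)."""
    n_sentences = 0
    assumptions, contraries, rules = [], {}, []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "p":      # header: "p aba <n_sentences>"
                n_sentences = int(parts[2])
            elif parts[0] == "a":    # assumption: "a <atom>"
                assumptions.append(parts[1])
            elif parts[0] == "c":    # contrary: "c <assumption> <contrary>"
                contraries[parts[1]] = parts[2]
            elif parts[0] == "r":    # rule: "r <head> <body atoms...>" (body may be empty)
                rules.append((parts[1], parts[2:]))
    return n_sentences, assumptions, contraries, rules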

scripts/decode-result-folder.py (new executable file): 62 lines changed

@@ -0,0 +1,62 @@
#!/usr/bin/env python3

import glob
import json
import os
import argparse
import csv

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--directory")
parser.add_argument("-o", "--output")
args = parser.parse_args()

# Path to the folder
folder_path = args.directory if args.directory is not None else "output"
output = args.output if args.output is not None else "all.csv"

def run():
    count = 0
    out = []
    # Using glob to match all .json files
    for file_path in glob.glob(os.path.join(folder_path, "*.json")):
        # Open and read the contents of the file
        with open(file_path, "r", encoding="utf-8") as json_file:
            (
                _ident,
                atom_count,
                assumption_ratio,
                max_rules_per_head,
                max_rule_size,
                _idx,
            ) = file_path.split("_")
            data = json.load(json_file)["results"]
            aba2sat, aspforaba = (
                (data[0], data[1])
                if (
                    data[0]["command"] == "aba2sat"
                    and data[1]["command"] == "aspforaba"
                )
                else (data[1], data[0])
            )
            speedup = float(aspforaba['mean']) / float(aba2sat['mean'])
            out.append({
                "atom_count": atom_count,
                "assumption_ratio": assumption_ratio,
                "max_rules_per_head": max_rules_per_head,
                "max_rule_size": max_rule_size,
                "time": aba2sat["mean"],
                "stddev": aba2sat['stddev'],
                "speedup": speedup,
            })
            if count > 700:
                break
            count += 1
    with open(output, 'w') as output_file:
        writer = csv.DictWriter(output_file, fieldnames=out[0].keys())
        writer.writeheader()
        writer.writerows(out)

run()
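The JSON files consumed here appear to have the shape of hyperfine's --export-json output (hyperfine is in the dev shell above): a top-level "results" list whose entries carry "command", "mean", and "stddev". An illustrative sample with made-up numbers, mirroring the speedup computation:

# Illustrative input shape (hyperfine --export-json style); the numbers are invented.
sample = {
    "results": [
        {"command": "aba2sat", "mean": 0.42, "stddev": 0.03},
        {"command": "aspforaba", "mean": 0.84, "stddev": 0.05},
    ]
}
aba2sat, aspforaba = sample["results"]
print(float(aspforaba["mean"]) / float(aba2sat["mean"]))  # speedup: 2.0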

scripts/generate_nn.py (new executable file): 69 lines changed

@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define the neural network architecture
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(28*28, 16)  # Input layer
        self.fc2 = nn.Linear(16, 16)     # Hidden layer 1
        self.fc3 = nn.Linear(16, 16)     # Hidden layer 2
        self.fc4 = nn.Linear(16, 16)     # Hidden layer 3
        self.fc5 = nn.Linear(16, 10)     # Output layer

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input images
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        x = self.fc5(x)
        return x

def run_training():
    # Get the number of physical cores
    num_physical_cores = psutil.cpu_count(logical=False)
    torch.set_num_threads(num_physical_cores)

    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=2000, shuffle=True)

    # Initialize the neural network
    model = NeuralNetwork()

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Training loop
    epochs = 10
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 10 == 9:  # Print every 10 mini-batches
                print('[%3d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0

    torch.save(model.state_dict(), 'data/model.pth')
    print('Finished Training')

if __name__ == '__main__':
    run_training()
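The network is deliberately tiny: four 16-unit layers on a flattened 28×28 input. A standalone back-of-the-envelope parameter count (not part of the commit):

# Weights + biases of each nn.Linear for the layer sizes above.
layers = [(28 * 28, 16), (16, 16), (16, 16), (16, 16), (16, 10)]
print(sum(i * o + o for i, o in layers))  # 13546 trainable parameters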

scripts/run-model.py (new executable file): 33 lines changed

@@ -0,0 +1,33 @@
#!/usr/bin/env python3

import torch
from PIL import Image
import torchvision.transforms as transforms

from generate_nn import NeuralNetwork

model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()

# Define a transform to convert the image to tensor and normalize it
# MNIST images are usually transformed to tensors and normalized with a mean of 0.1307 and std of 0.3081
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale if the image is in color
    transforms.Resize((28, 28)),                  # Resize image to 28x28 pixels
    transforms.ToTensor(),                        # Convert to PyTorch Tensor
    transforms.Normalize((0.1307,), (0.3081,)),   # Normalize pixel values
])

# Load the image
image = Image.open('test.png')

# Apply the transform to the image
image = transform(image)

# Add an extra batch dimension since PyTorch treats all inputs as batches
image = torch.unsqueeze(image, 0)

with torch.no_grad():
    output = model(image)
    predicted_class = output.argmax(dim=1).item()

print('Predicted class:', predicted_class)
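The unsqueeze step matters because the model expects batched input: the transform pipeline yields a (1, 28, 28) channel-height-width tensor, and unsqueezing prepends the batch axis. A standalone shape check (not part of the commit):

import torch

image = torch.zeros(1, 28, 28)     # what the transform pipeline yields: (C, H, W)
batch = torch.unsqueeze(image, 0)  # prepend the batch dimension
print(batch.shape)                 # torch.Size([1, 1, 28, 28]), i.e. (N, C, H, W)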

(unnamed Rust source file)

@@ -101,6 +101,13 @@ impl Aba {
     }

     /// Prepare this aba for translation to SAT
+    #[cfg_attr(
+        feature = "timing",
+        fun_time::fun_time(
+            message = "Preparing ABA with max {max_loops:?} loops",
+            reporting = "log"
+        )
+    )]
     pub fn prepare(self, max_loops: Option<usize>) -> PreparedAba {
         PreparedAba::new(self, max_loops)
     }
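The cfg_attr wrapper means the fun_time attribute, and with it the log line interpolating max_loops, is only compiled when the timing feature from Cargo.toml is enabled; without the feature, prepare compiles exactly as before.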

(unnamed Rust source file)

@@ -27,7 +27,10 @@ impl PreparedAba {
     /// Create a new [`PreparedAba`] from a raw [`Aba`]
     pub fn new(mut aba: Aba, max_loops: Option<usize>) -> Self {
         trim_unreachable_rules(&mut aba);
-        let loops = calculate_loops_and_their_support(&aba, max_loops).collect();
+        let loops = match max_loops {
+            Some(0) => vec![],
+            _ => calculate_loops_and_their_support(&aba, max_loops).collect(),
+        };
         PreparedAba { aba, loops }
     }

     /// Translate the ABA into base rules / definitions for SAT solving
@@ -138,6 +141,10 @@ impl PreparedAba {
 /// Iterates over all rules, marking reachable elements until
 /// no additional rule can be applied. Then removes every
 /// rule that contains any unreachable atom and returns the rest
+#[cfg_attr(
+    feature = "timing",
+    fun_time::fun_time(message = "Trimming unnecessary rules from ABA", reporting = "log")
+)]
 fn trim_unreachable_rules(aba: &mut Aba) {
     // Begin with all assumptions marked as reachable
     let mut reachable: HashSet<_> = aba.assumptions().cloned().collect();

(unnamed Rust source file)

@@ -70,7 +70,7 @@ pub fn solve<P: Problem>(problem: P, aba: Aba, max_loops: Option<usize>) -> Resu
     map.as_raw_iter(&additional_clauses)
         .for_each(|raw| sat.add_clause(raw));
     // A single solver call to determine the solution
-    if let Some(sat_result) = sat.solve() {
+    if let Some(sat_result) = call_sat_solver(&mut sat) {
         #[cfg(debug_assertions)]
         if sat_result {
             let rec = map.reconstruct(&sat).collect::<Vec<_>>();
@@ -119,7 +119,7 @@ pub fn multishot_solve<P: MultishotProblem>(
     map.as_raw_iter(&additional_clauses)
         .for_each(|raw| sat.add_clause(raw));
     // Call the solver for the next result
-    let sat_result = sat.solve().ok_or(Error::SatCallInterrupted)?;
+    let sat_result = call_sat_solver(&mut sat).ok_or(Error::SatCallInterrupted)?;
     #[cfg(debug_assertions)]
     if sat_result {
         let rec = map.reconstruct(&sat).collect::<Vec<_>>();
@@ -154,3 +154,11 @@ pub fn multishot_solve<P: MultishotProblem>(
         iteration,
     ))
 }
+
+#[cfg_attr(
+    feature = "timing",
+    fun_time::fun_time(message = "Calling SAT solver", reporting = "log")
+)]
+fn call_sat_solver(sat: &mut Solver) -> Option<bool> {
+    sat.solve()
+}
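Wrapping sat.solve() in call_sat_solver looks redundant, but fun_time instruments whole functions, so the wrapper gives the optional timing feature a boundary to measure; both solve paths above now route through it.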

(unnamed Rust source file)

@@ -29,6 +29,7 @@ trait IccmaFormattable {
 }

 fn __main() -> Result {
+    pretty_env_logger::init();
     let args = match &*ARGS {
         Some(args) => args,
         None => {