continue with scripts and debugging
This commit is contained in:
parent
ef5f0606c0
commit
91b51fe38b
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -2,3 +2,6 @@
|
|||
result*
|
||||
target/
|
||||
.pre-commit-config.yaml
|
||||
scripts/__pycache__
|
||||
acyclic
|
||||
output-*
|
||||
|
|
204
Cargo.lock
generated
204
Cargo.lock
generated
|
@ -8,11 +8,23 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"cadical",
|
||||
"clap",
|
||||
"fun_time",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"nom",
|
||||
"pretty_env_logger",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.12"
|
||||
|
@ -105,7 +117,7 @@ dependencies = [
|
|||
"anstream",
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
"strsim 0.11.0",
|
||||
"terminal_size",
|
||||
]
|
||||
|
||||
|
@ -118,7 +130,7 @@ dependencies = [
|
|||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.49",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -133,6 +145,54 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.14.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.14.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim 0.10.0",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.14.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
|
||||
dependencies = [
|
||||
"humantime",
|
||||
"is-terminal",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.8"
|
||||
|
@ -143,12 +203,70 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "fun_time"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bee194d43605ea83cca7af42af5f9001fab1a8e2220cb8a012e21dda6167fdb0"
|
||||
dependencies = [
|
||||
"fun_time_derive",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fun_time_derive"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71555fd2db00938d82d29d8fa62a2ae80aed2c162c328d775f79e98d9212f013"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.28"
|
||||
|
@ -176,6 +294,12 @@ version = "0.4.13"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.1"
|
||||
|
@ -198,6 +322,16 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pretty_env_logger"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c"
|
||||
dependencies = [
|
||||
"env_logger",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.78"
|
||||
|
@ -216,6 +350,35 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.31"
|
||||
|
@ -229,12 +392,29 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.49"
|
||||
|
@ -246,6 +426,15 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "terminal_size"
|
||||
version = "0.3.0"
|
||||
|
@ -273,7 +462,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.49",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -288,6 +477,15 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.48.0"
|
||||
|
|
|
@ -12,6 +12,12 @@ default-run = "aba2sat"
|
|||
[dependencies]
|
||||
cadical = "0.1.14"
|
||||
clap = { version = "4.4.8", features = ["wrap_help", "derive"] }
|
||||
fun_time = { version = "0.3.4", optional = true, features = ["log"] }
|
||||
lazy_static = "1.4.0"
|
||||
log = "0.4.21"
|
||||
nom = "7.1.3"
|
||||
pretty_env_logger = "0.5.0"
|
||||
thiserror = "1.0.50"
|
||||
|
||||
[features]
|
||||
timing = ["dep:fun_time"]
|
||||
|
|
|
@ -185,7 +185,10 @@
|
|||
drv = aba2sat;
|
||||
};
|
||||
|
||||
devShells.default = craneLib.devShell {
|
||||
devShells.default = let
|
||||
python = pkgs.python3.withPackages (ps: [ps.torch ps.torchvision ps.psutil]);
|
||||
in
|
||||
craneLib.devShell {
|
||||
# Inherit inputs from checks.
|
||||
checks = self.checks.${system};
|
||||
|
||||
|
@ -203,6 +206,7 @@
|
|||
pkgs.ruff-lsp
|
||||
pkgs.shellcheck
|
||||
pkgs.shfmt
|
||||
python
|
||||
self'.packages.aspforaba
|
||||
];
|
||||
};
|
||||
|
|
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
92
scripts/aba_generator_acyclic.py
Normal file → Executable file
92
scripts/aba_generator_acyclic.py
Normal file → Executable file
|
@ -1,9 +1,9 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import random
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
def create_framework(n_sentences, n_assumptions, n_rules_per_head,
|
||||
size_of_bodies, cycle_prob):
|
||||
def create_framework(n_sentences, n_assumptions, n_rules_per_head, size_of_bodies, cycle_prob):
|
||||
"""
|
||||
Create a random framework.
|
||||
|
||||
|
@ -15,8 +15,8 @@ def create_framework(n_sentences, n_assumptions, n_rules_per_head,
|
|||
- max(size_of_bodies) <= n_sentences+n_assumptions
|
||||
"""
|
||||
|
||||
assumptions = ["a" + str(i) for i in range(n_assumptions)]
|
||||
sentences = ["s" + str(i) for i in range(n_sentences-n_assumptions)]
|
||||
assumptions = [str(i) for i in range(1,n_assumptions+1)]
|
||||
sentences = [str(i) for i in range(n_assumptions+1,n_sentences+1)]
|
||||
|
||||
contraries = {asmpt: random.choice(sentences+assumptions) for asmpt in assumptions}
|
||||
|
||||
|
@ -42,56 +42,54 @@ def create_framework(n_sentences, n_assumptions, n_rules_per_head,
|
|||
|
||||
return assumptions, sentences, contraries, rules
|
||||
|
||||
def print_ASP(assumptions, contraries, rules, out_filename, query=None):
|
||||
def print_ICCMA_format(assumptions, contraries, rules, n_sentences, out_filename):
|
||||
"""
|
||||
Print the given framework in ASP format.
|
||||
Print the given framework in the ICCMA 2023 format.
|
||||
"""
|
||||
offset = len(assumptions)
|
||||
|
||||
with open(out_filename, 'w') as out:
|
||||
for asm in assumptions:
|
||||
out.write("assumption(" + asm + ").\n")
|
||||
out.write(f"p aba {n_sentences}\n")
|
||||
for i, asm in enumerate(assumptions):
|
||||
out.write(f"a {asm}\n")
|
||||
#print(f"a {asm}")
|
||||
for ctr in contraries:
|
||||
out.write("contrary(" + ctr + "," + contraries.get(ctr) + ").\n")
|
||||
for i, rule in enumerate(rules):
|
||||
out.write("head(" + str(i) + "," + rule[0] + ").\n")
|
||||
if rule[1]:
|
||||
for body in rule[1]:
|
||||
out.write("body(" + str(i) + "," + body + ").\n")
|
||||
if query:
|
||||
out.write("query(" + query + ").")
|
||||
|
||||
n_sentences = int(sys.argv[1])
|
||||
cycle_prob = float(sys.argv[2])
|
||||
max_rules_per_head = 5
|
||||
max_body_size = 5
|
||||
n_a = int(round(0.15*n_sentences))
|
||||
n_rph = range(1,max_rules_per_head+1)
|
||||
n_spb = range(1,max_body_size)
|
||||
|
||||
framework = create_framework(n_sentences, n_a, n_rph, n_spb, cycle_prob)
|
||||
print_ASP(framework[0], framework[2], framework[3], "generated_benchmark.asp", "s0")
|
||||
out.write(f"c {ctr} {contraries.get(ctr)}\n")
|
||||
#print(f"c {ctr} {contraries.get(ctr)}")
|
||||
for rule in rules:
|
||||
out.write(f"r {rule[0]} {' '.join(rule[1])}\n")
|
||||
#print(f"r {rule[0]} {' '.join(rule[1])}")
|
||||
|
||||
def ICCMA23_benchmarks(sentences=[1000,2000,3000,4000,5000], max_rules_per_head_list=[5,10], max_rule_size_list=[5,10], assumption_ratios=[0.1,0.3], count=10, directory="iccma23_aba_benchmarks", identifier="aba"):
|
||||
random.seed(811543731122527)
|
||||
for sentence in sentences:
|
||||
for assumption_ratio in assumption_ratios:
|
||||
for max_rules_per_head in max_rules_per_head_list:
|
||||
for max_rule_size in max_rule_size_list:
|
||||
for i in range(count):
|
||||
number_assumptions = int(round(assumption_ratio*sentence))
|
||||
number_rules_per_head = range(1,max_rules_per_head+1)
|
||||
n_spb = range(1,max_rule_size+1)
|
||||
filename = f"{directory}/{identifier}_{sentence}_{assumption_ratio}_{max_rules_per_head}_{max_rule_size}_{i}.aba"
|
||||
print(filename)
|
||||
framework = create_framework(sentence, number_assumptions, number_rules_per_head, n_spb, 0)
|
||||
query = random.randint(1,number_assumptions)
|
||||
with open(f"{filename}.asm", 'w') as out:
|
||||
print(f"{filename}.asm")
|
||||
out.write(f"{query}")
|
||||
print_ICCMA_format(framework[0], framework[2], framework[3], sentence, filename)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-d', '--directory')
|
||||
parser.add_argument('-i', '--identifier')
|
||||
args = parser.parse_args()
|
||||
|
||||
directory = args.directory
|
||||
identifier = args.identifier
|
||||
|
||||
sens = [1000,2000,3000,4000,5000]
|
||||
n_rules_max = [2,5,8,13]
|
||||
rule_size_max = [2,5,8,13]
|
||||
asmpt_ratio = [0.15,0.3,0.7]
|
||||
for sen in sens:
|
||||
for k in asmpt_ratio:
|
||||
for rph_max in n_rules_max:
|
||||
for spb_max in rule_size_max:
|
||||
for i in range(10):
|
||||
n_a = int(round(k*sen))
|
||||
n_rph = range(1,rph_max+1)
|
||||
n_spb = range(1,spb_max+1)
|
||||
filename = f"{directory}/{identifier}_{sen}_{k}_{rph_max}_{spb_max}_{i}.asp"
|
||||
print(filename)
|
||||
framework = create_framework(sen, n_a, n_rph, n_spb)
|
||||
print_ASP(framework[0], framework[2], framework[3], filename)
|
||||
ICCMA23_benchmarks(
|
||||
sentences = [50,100,200,300,400,500,1000,2000],
|
||||
max_rules_per_head_list = [1,2,4,8,16],
|
||||
max_rule_size_list = [1,2,4,8,16],
|
||||
assumption_ratios = [0.1,0.3,0.5,0.7,0.9],
|
||||
count = 5,
|
||||
directory=args.directory if args.directory is not None else "acyclic",
|
||||
identifier=args.identifier if args.identifier is not None else "aba",
|
||||
)
|
||||
|
|
62
scripts/decode-result-folder.py
Executable file
62
scripts/decode-result-folder.py
Executable file
|
@ -0,0 +1,62 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import csv
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-d", "--directory")
|
||||
parser.add_argument("-o", "--output")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Path to the folder
|
||||
folder_path = args.directory if args.directory is not None else "output"
|
||||
output = args.output if args.output is not None else "all.csv"
|
||||
|
||||
|
||||
def run():
|
||||
count = 0
|
||||
out = []
|
||||
# Using glob to match all .json files
|
||||
for file_path in glob.glob(os.path.join(folder_path, "*.json")):
|
||||
# Open and read the contents of the file
|
||||
with open(file_path, "r", encoding="utf-8") as json_file:
|
||||
(
|
||||
_ident,
|
||||
atom_count,
|
||||
assumption_ratio,
|
||||
max_rules_per_head,
|
||||
max_rule_size,
|
||||
_idx,
|
||||
) = file_path.split("_")
|
||||
data = json.load(json_file)["results"]
|
||||
aba2sat, aspforaba = (
|
||||
(data[0], data[1])
|
||||
if (
|
||||
data[0]["command"] == "aba2sat"
|
||||
and data[1]["command"] == "aspforaba"
|
||||
)
|
||||
else (data[1], data[0])
|
||||
)
|
||||
speedup = float(aspforaba['mean']) / float(aba2sat['mean'])
|
||||
out.append({
|
||||
"atom_count": atom_count,
|
||||
"assumption_ratio": assumption_ratio,
|
||||
"max_rules_per_head": max_rules_per_head,
|
||||
"max_rule_size": max_rule_size,
|
||||
"time": aba2sat["mean"],
|
||||
"stddev": aba2sat['stddev'],
|
||||
"speedup": speedup,
|
||||
})
|
||||
if count > 700:
|
||||
break
|
||||
count += 1
|
||||
with open(output, 'w') as output_file:
|
||||
output_file.write
|
||||
writer = csv.DictWriter(output_file, fieldnames=out[0].keys())
|
||||
writer.writeheader()
|
||||
writer.writerows(out)
|
||||
|
||||
run()
|
69
scripts/generate_nn.py
Executable file
69
scripts/generate_nn.py
Executable file
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import psutil
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
|
||||
# Define the neural network architecture
|
||||
class NeuralNetwork(nn.Module):
|
||||
def __init__(self):
|
||||
super(NeuralNetwork, self).__init__()
|
||||
self.fc1 = nn.Linear(28*28, 16) # Input layer
|
||||
self.fc2 = nn.Linear(16, 16) # Hidden layer 1
|
||||
self.fc3 = nn.Linear(16, 16) # Hidden layer 2
|
||||
self.fc4 = nn.Linear(16, 16) # Hidden layer 3
|
||||
self.fc5 = nn.Linear(16, 10) # Output layer
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 28*28) # Flatten the input images
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = torch.relu(self.fc2(x))
|
||||
x = torch.relu(self.fc3(x))
|
||||
x = torch.relu(self.fc4(x))
|
||||
x = self.fc5(x)
|
||||
return x
|
||||
|
||||
def run_training():
|
||||
# Get the number of physical cores
|
||||
num_physical_cores = psutil.cpu_count(logical=False)
|
||||
|
||||
torch.set_num_threads(num_physical_cores)
|
||||
|
||||
# Load MNIST dataset
|
||||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
|
||||
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
|
||||
train_loader = DataLoader(train_dataset, batch_size=2000, shuffle=True)
|
||||
|
||||
# Initialize the neural network
|
||||
model = NeuralNetwork()
|
||||
|
||||
# Define loss function and optimizer
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# Training loop
|
||||
epochs = 10
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(train_loader, 0):
|
||||
inputs, labels = data
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
if i % 10 == 9: # Print every 10 mini-batches
|
||||
print('[%3d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
|
||||
running_loss = 0.0
|
||||
|
||||
torch.save(model.state_dict(), 'data/model.pth')
|
||||
print('Finished Training')
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_training()
|
33
scripts/run-model.py
Executable file
33
scripts/run-model.py
Executable file
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from generate_nn import NeuralNetwork
|
||||
|
||||
model = NeuralNetwork()
|
||||
model.load_state_dict(torch.load('data/model.pth'))
|
||||
model.eval()
|
||||
|
||||
# Define a transform to convert the image to tensor and normalize it
|
||||
# MNIST images are usually transformed to tensors and normalized with a mean of 0.1307 and std of 0.3081
|
||||
transform = transforms.Compose([
|
||||
transforms.Grayscale(num_output_channels=1), # Convert to grayscale if the image is in color
|
||||
transforms.Resize((28, 28)), # Resize image to 28x28 pixels
|
||||
transforms.ToTensor(), # Convert to PyTorch Tensor
|
||||
transforms.Normalize((0.1307,), (0.3081,)), # Normalize pixel values
|
||||
])
|
||||
|
||||
# Load the image
|
||||
image = Image.open('test.png')
|
||||
|
||||
# Apply the transform to the image
|
||||
image = transform(image)
|
||||
|
||||
# Add an extra batch dimension since PyTorch treats all inputs as batches
|
||||
image = torch.unsqueeze(image, 0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(image)
|
||||
predicted_class = output.argmax(dim = 1).item()
|
||||
print('Predicted class:', predicted_class)
|
|
@ -101,6 +101,13 @@ impl Aba {
|
|||
}
|
||||
|
||||
/// Prepare this aba for translation to SAT
|
||||
#[cfg_attr(
|
||||
feature = "timing",
|
||||
fun_time::fun_time(
|
||||
message = "Preparing ABA with max {max_loops:?} loops",
|
||||
reporting = "log"
|
||||
)
|
||||
)]
|
||||
pub fn prepare(self, max_loops: Option<usize>) -> PreparedAba {
|
||||
PreparedAba::new(self, max_loops)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,10 @@ impl PreparedAba {
|
|||
/// Create a new [`PreparedAba`] from a raw [`Aba`]
|
||||
pub fn new(mut aba: Aba, max_loops: Option<usize>) -> Self {
|
||||
trim_unreachable_rules(&mut aba);
|
||||
let loops = calculate_loops_and_their_support(&aba, max_loops).collect();
|
||||
let loops = match max_loops {
|
||||
Some(0) => vec![],
|
||||
_ => calculate_loops_and_their_support(&aba, max_loops).collect(),
|
||||
};
|
||||
PreparedAba { aba, loops }
|
||||
}
|
||||
/// Translate the ABA into base rules / definitions for SAT solving
|
||||
|
@ -138,6 +141,10 @@ impl PreparedAba {
|
|||
/// Iterates over all rules, marking reachable elements until
|
||||
/// no additional rule can be applied. Then removes every
|
||||
/// rule that contains any unreachable atom and returns the rest
|
||||
#[cfg_attr(
|
||||
feature = "timing",
|
||||
fun_time::fun_time(message = "Triming unnecessary rules from ABA", reporting = "log")
|
||||
)]
|
||||
fn trim_unreachable_rules(aba: &mut Aba) {
|
||||
// Begin with all assumptions marked as reachable
|
||||
let mut reachable: HashSet<_> = aba.assumptions().cloned().collect();
|
||||
|
|
|
@ -70,7 +70,7 @@ pub fn solve<P: Problem>(problem: P, aba: Aba, max_loops: Option<usize>) -> Resu
|
|||
map.as_raw_iter(&additional_clauses)
|
||||
.for_each(|raw| sat.add_clause(raw));
|
||||
// A single solver call to determine the solution
|
||||
if let Some(sat_result) = sat.solve() {
|
||||
if let Some(sat_result) = call_sat_solver(&mut sat) {
|
||||
#[cfg(debug_assertions)]
|
||||
if sat_result {
|
||||
let rec = map.reconstruct(&sat).collect::<Vec<_>>();
|
||||
|
@ -119,7 +119,7 @@ pub fn multishot_solve<P: MultishotProblem>(
|
|||
map.as_raw_iter(&additional_clauses)
|
||||
.for_each(|raw| sat.add_clause(raw));
|
||||
// Call the solver for the next result
|
||||
let sat_result = sat.solve().ok_or(Error::SatCallInterrupted)?;
|
||||
let sat_result = call_sat_solver(&mut sat).ok_or(Error::SatCallInterrupted)?;
|
||||
#[cfg(debug_assertions)]
|
||||
if sat_result {
|
||||
let rec = map.reconstruct(&sat).collect::<Vec<_>>();
|
||||
|
@ -154,3 +154,11 @@ pub fn multishot_solve<P: MultishotProblem>(
|
|||
iteration,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
feature = "timing",
|
||||
fun_time::fun_time(message = "Calling SAT solver", reporting = "log")
|
||||
)]
|
||||
fn call_sat_solver(sat: &mut Solver) -> Option<bool> {
|
||||
sat.solve()
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ trait IccmaFormattable {
|
|||
}
|
||||
|
||||
fn __main() -> Result {
|
||||
pretty_env_logger::init();
|
||||
let args = match &*ARGS {
|
||||
Some(args) => args,
|
||||
None => {
|
||||
|
|
Loading…
Reference in a new issue