# SMT Layer

Implementation of SMTLayer in PyTorch with Z3.

This code is the Python companion to **Grounding Neural Inference with Satisfiability Modulo Theories**, which appeared at NeurIPS 2023.

In this paper we present a set of techniques for integrating Satisfiability Modulo Theories (SMT) solvers into the forward and backward passes of a deep network layer, called SMTLayer. **Notably, the solver need not be differentiable.** We implement SMTLayer as a PyTorch module. An overview of our approach is shown below.

<img width="1026" alt="smt_layer" src="https://github.com/cmu-transparency/smt-layer/assets/9357853/35ae93dc-af5d-4f91-82de-518dd7434faa">

SMTLayers, used on top of other neural network layers, can be leveraged to solve many tasks that require logical reasoning, such as adding two digits shown in images. Moreover, we show how to solve visual Sudoku, the Liar's Puzzle (above), and more.
<img width="660" alt="example" src="https://github.com/cmu-transparency/smt-layer/assets/9357853/ac36a246-9322-41ef-9096-cdd2916e76ee">


We implement SMTLayer as a PyTorch module, and our empirical results show that it leads to models that 1) require fewer training samples than conventional models, 2) are robust to certain types of covariate shift, and 3) learn representations that are consistent with symbolic knowledge, and are thus naturally interpretable.

<img width="1000" alt="table" src="https://github.com/cmu-transparency/smt-layer/assets/9357853/4afe0cf3-043f-4945-902f-ee600a38d23b">


## Prerequisites

You will need the following dependencies installed on your system:

- Python (3.6+)
- PyTorch
- torchvision
- matplotlib
- livelossplot
- z3
Depending on your operating system, you can install z3 using different methods. Here's how to install z3 using `pip`:

```bash
pip install z3-solver
```

Otherwise, follow the instructions [here](https://github.com/Z3Prover/z3).


## Data Preparation
Datasets used in the paper are included in `notebook/`, except MNIST, which is publicly available. For example, you can download it with torchvision:
```python
import torchvision
from torchvision import transforms

# Define a transformation
transform = transforms.Compose([transforms.ToTensor()])

# Download and transform the training dataset
mnist_train = torchvision.datasets.MNIST(
    '/data/data',
    train=True,
    download=True,
    transform=transform
)

# Download and transform the test dataset
mnist_test = torchvision.datasets.MNIST(
    '/data/data',
    train=False,
    download=True,
    transform=transform
)
```

Replace `/data/data` with the path where you want to store the dataset.
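As an illustration of how a task like two-digit addition might be prepared from such data, the sketch below pairs samples and labels each pair with the sum of its two digit labels. The pairing scheme and the random stand-in tensors are our own illustration, not the paper's exact preprocessing; in practice you would draw `images` and `labels` from the `mnist_train` dataset above.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in data shaped like MNIST (1x28x28 images, labels 0-9).
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))

# Pair consecutive samples; the training target is the sum of the two digits.
pair_imgs = torch.stack([images[0::2], images[1::2]], dim=1)  # (50, 2, 1, 28, 28)
pair_sums = labels[0::2] + labels[1::2]                       # values in 0..18

loader = DataLoader(TensorDataset(pair_imgs, pair_sums), batch_size=10, shuffle=True)
xb, yb = next(iter(loader))
```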


## Acknowledgments

If you use this code, please cite:
```
@inproceedings{
    wang2023grounding,
    title={Grounding Neural Inference with Satisfiability Modulo Theories},
    author={Zifan Wang and Saranya Vijaykumar and Kaiji Lu and Vijay Ganesh and Somesh Jha and Matt Fredrikson},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
    url={https://openreview.net/forum?id=r8snfquzs3}
}
```