pycosat中慢速dnf到cnf

发布于 2025-01-10 02:57:03 字数 11202 浏览 1 评论 0原文

简短的问题

要为 pycosat 提供正确的输入,是否有有什么方法可以加快从 dnf 到 cnf 的计算速度,或者完全规避它?

详细问题

我一直在观看 Raymond 的此视频 Hettinger 关于现代求解器。我下载了代码,并为游戏 Towers 实现了一个解算器 在其中。下面我分享了这样做的代码。

示例塔谜题(已解决):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

我遇到的问题是从 dnf 到 cnf 的转换需要很长时间。假设您知道从某个视线可以看到 3 座塔。这导致该行中有 35 种可能的排列 1-5。

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

这是一个析取范式:多个 AND 语句的 OR。这需要转换为连接范式:多个 OR 语句的 AND。然而,这非常慢。在我的 Macbook Pro 上,它在 5 分钟后没有完成单行的 cnf 计算。对于整个拼图,此操作最多应完成 20 次(对于 5x5 网格)。

为了使计算机能够解决这个塔谜题,优化这段代码的最佳方法是什么?

此代码也可从此 Github 存储库获取。

import string

import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            elems = [str(self.visible_from_left[index]) or ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index]) or ' ']
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()

实际时间花费在该辅助函数中,来自视频附带的使用的辅助代码。

def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))

使用 4x4 网格时 pyinstrument 的输出显示,此处的 cnf = { ... } 行是罪魁祸首:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:05:58  Samples:  146
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.515     CPU time: 0.506
/   _/                      v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..

Question in short

To have a proper input for pycosat, is there a way to speed up calculation from dnf to cnf, or to circumvent it altogether?

Question in detail

I have been watching this video from Raymond Hettinger about modern solvers. I downloaded the code, and implemented a solver for the game Towers in it. Below I share the code to do so.

Example Tower puzzle (solved):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

The problem I encounter is that the conversion from dnf to cnf takes forever. Let's say that you know there are 3 towers visible from a certain line of sight. This leads to 35 possible permutations 1-5 in that row.

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

This is a disjunctive normal form: an OR of several AND statements. This needs to be converted into a conjunctive normal form: an AND of several OR statements. This is however very slow. On my Macbook Pro, it didn't finish calculating this cnf after 5 minutes for a single row. For the entire puzzle, this should be done up to 20 times (for a 5x5 grid).

What would be the best way to optimize this code, in order to make the computer able to solve this Towers puzzle?

This code is also available from this Github repository.

import string

import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            elems = [str(self.visible_from_left[index]) or ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index]) or ' ']
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()

The actual time is spent in this helper function from the used helper code that came along with the video.

def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))

The output from pyinstrument when using a 4x4 grid shows that the line cnf = { ... } in here is the culprit:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:05:58  Samples:  146
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.515     CPU time: 0.506
/   _/                      v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

野稚 2025-01-17 02:57:03

首先,最好注意等价性和等可满足性之间的区别。一般来说,将任意布尔公式(例如 DNF 中的某些公式)转换为 CNF 可能会导致大小呈指数级增长。

这种爆炸是您的 from_dnf 方法的问题:每当您处理另一个产品术语时,该产品中的每个文字都需要当前 cnf 子句集的新副本(它将在每个子句中添加自己)。如果您有 n 个大小为 k 的乘积项,则增长为 O(k^n)

在你的例子中,n实际上是k!的函数。保留为乘积项的内容将被过滤为满足视图约束的项,但程序的总体运行时间大致在 O(k^f(k!)) 范围内。即使 f 以对数方式增长,这仍然是 O(k^(k lg k)) 并且不太理想!

因为您问的是“这是可满足的吗?”,所以您不需要一个等价公式,而只需要一个可满足公式。这是一些新公式,当且仅当原始公式是可满足的时,它是可满足的,但相同的分配可能无法满足该新公式。

例如,(a ∨ b)(a ∨ c) ∧ (Øb) 显然都是可满足的,因此它们是可等满足的。但是将 b 设置为 true 会满足第一个条件,而会证伪第二个条件,因此它们并不等效。此外,第一个甚至没有 c 作为变量,再次使其不等于第二个。

这种松弛足以用线性大小的平移代替这种指数爆炸。


关键思想是使用扩展变量。这些是新变量(即,公式中尚未存在),允许我们缩写表达式,因此我们最终不会在翻译中制作它们的多个副本。由于新变量不存在于原始变量中,因此我们将不再有等效的公式;但因为当且仅当表达式为真时变量才为真,所以它是可等满足的。

如果我们想使用x作为y的缩写,我们会声明x ≡ y。这与 x → yy → x 相同,与 (Øx ∨ y) ∧ (Øy ∨ x) 相同code>,它已经在 CNF 中了。

考虑乘积项的缩写:x ≡ (a ∧ b)。这是 x → (a ∧ b)(a ∧ b) → x,它是三个子句:(Øx ∨ a) ∧ (Øx∨b)∧(Øa∨Øb∨x)。一般来说,用 x 缩写 k 个文字的乘积项将产生 k 个二进制子句,表示 x 隐含其中的每一个,并且一个 (k+1)< /code>-子句表示它们一起意味着x。这在k 中是线性的。

要真正了解为什么这有帮助,请尝试将 (a ∧ b ∧ c) ∨ (d ∧ e ∧ f) ∨ (g ∧ h ∧ i) 转换为带有或不带有扩展变量的等效 CNF第一个乘积项。当然,我们不会仅仅停留在一个术语上:如果我们缩写每个术语,那么结果恰好是一个 CNF 子句:(x ∨ y ∨ z) 其中每个术语都缩写为一个乘积术语。这个小了很多!

这种方法可用于将任何电路转换为可等满足的公式,其大小和 CNF 呈线性。这称为Tseitin 变换。您的 DNF 公式只是一个由一堆任意扇入与门组成的电路,所有输入都馈入单个任意扇入或门。

最重要的是,尽管由于附加变量,该公式并不等效,但我们可以通过简单地删除扩展变量来恢复原始公式的赋值。它是一种“最佳情况”可等满足公式,是原始公式的严格超集。


为了将其修补到您的代码中,我添加了:

# Uses pseudo-namespacing to avoid collisions.
_EXT_SUFFIX = "___"
_NEXT_EXT_INDEX = 0


def is_ext_var(element) -> bool:
    return element.endswith(_EXT_SUFFIX)


def ext_var() -> str:
    global _NEXT_EXT_INDEX
    ext_index = _NEXT_EXT_INDEX
    _NEXT_EXT_INDEX += 1

    return intern(f"{ext_index}{_EXT_SUFFIX}")

这让我们可以凭空提取一个新的命名变量。由于这些扩展变量名称对您的解决方案显示功能没有有意义的语义,因此我将: 更改

point_to_value = {
    point: value for point, value in [fact.split() for fact in self.solution]
}

为:

point_to_value = {
    point: value
    for point, value in [
        fact.split() for fact in self.solution if not is_ext_var(fact)
    ]
}

当然有更好的方法可以做到这一点,这只是一个补丁。 :)

使用上述想法重新实现您的 from_dnf,我们得到:

def from_dnf(groups) -> "cnf":
    "Convert from or-of-ands to and-of-ors, equisatisfiably"
    cnf = []

    extension_vars = []
    for group in groups:
        extension_var = ext_var()
        neg_extension_var = neg(extension_var)

        imply_ext_clause = []
        for literal in group:
            imply_ext_clause.append(neg(literal))
            cnf.append((neg_extension_var, literal))

        imply_ext_clause.append(extension_var)
        cnf.append(tuple(imply_ext_clause))

        extension_vars.append(extension_var)

    cnf.append(tuple(extension_vars))
    return cnf

每个组都有一个扩展变量。组中的每个文字将其否定添加到 (k+1) 大小的蕴涵子句中,并由扩展隐含。处理文字后,扩展变量完成剩余的含义并将其自身添加到新扩展变量的列表中。最后,这些扩展变量中至少有一个必须为 true。

仅此更改就可以让我立即解决这个 5x5 难题:

self.visible_from_top = [3, 2, 1, 4, 2]
self.visible_from_bottom = [2, 2, 4, 1, 2]
self.visible_from_left = [3, 2, 3, 1, 3]
self.visible_from_right = [2, 2, 1, 3, 2]
self.given_numbers = {}

我还添加了一些计时输出:

@property
def solution(self):
    if self._solution is None:
        start_time = time.perf_counter()

        cnf = self.cnf
        cnf_time = time.perf_counter()
        print(f"CNF: {cnf_time - start_time}s")

        self._solution = solve_one(cnf)
        end_time = time.perf_counter()
        print(f"Solve: {end_time - cnf_time}s")
    return self._solution

5x5 难题给了我:

CNF: 0.00565183162689209s
Solve: 0.005589433014392853s

但是,在枚举可行的塔高度排列时,我们仍然存在令人讨厌的 k! 增长。

我生成了一个 9x9 的谜题(网站允许的最大),对应于:

self.visible_from_top = [3, 3, 3, 3, 1, 4, 2, 4, 2]
self.visible_from_bottom = [3, 1, 4, 2, 5, 3, 3, 2, 3]
self.visible_from_left = [3, 3, 1, 2, 4, 5, 2, 3, 2]
self.visible_from_right = [3, 1, 7, 4, 3, 3, 2, 2, 4]
self.given_numbers = {
    "AB": 5,
    "AD": 4,
    "BD": 3,
    "BE": 2,
    "CD": 7,
    "CF": 5,
    "CG": 1,
    "DB": 1,
    "DH": 7,
    "EA": 4,
    "EI": 2,
    "FA": 2,
    "FE": 8,
    "GG": 7,
    "GI": 6,
    "HA": 3,
    "HF": 2,
    "HH": 1,
    "IG": 6,
}

这给了我:

CNF: 28.505195066332817s
Solve: 40.48229135945439s

我们应该花更多的时间解决问题,更少的时间生成,但接近一半的时间都在生成。

在我看来,在 CNF-SAT 翻译中使用 DNF 通常是错误方法的标志。求解器比我们更善于探索和了解解空间——花费阶乘时间进行预探索实际上比求解器的指数最坏情况更糟糕。

“回退”到 DNF 是可以理解的,因为程序员自然会想到“编写一个能够给出解决方案的算法”。但是,当您将其编码到问题中时,求解器的真正好处就会显现出来。让求解器推理解决方案不可行的条件。为此,我们需要从电路的角度进行思考。幸运的是,我们还知道如何快速将电路转变为 CNF。

†​​我说“经常”; 预先计算一些解决方案空间会很有帮助。


如果您的 DNF 很小并且可以快速生成(例如单个电路门),或者如果将其编码为电路非常复杂,那么 实际上已经做了一些!例如,我们需要一个电路来计算某个数字在一个范围(行或列)中出现的次数,并断言该数字恰好是一。然后对于每个跨度和每个数字,我们将发出这个电路。这样,如果尺寸为 3 的塔连续出现两次,则该行 3 的计数器将发出“2”,并且我们关于其为“1”的断言将不会得到支持。

您的 one_of 约束是 一种可能的实现。您使用“明显”的成对编码:对于跨度中的每个位置,如果 N 出现在该位置,那么它不会出现在任何其他位置。这实际上是一个非常好的编码,因为它几乎完全由二进制子句组成,并且 SAT 求解器喜欢二进制子句(它们使用的内存显着减少并且经常传播)。但对于真正需要计数的大量事物,这种 O(n^2) 缩放可能会成为一个问题。

您可以想象一种替代方法,您可以对加法器电路进行字面编码:每个位置都是电路的输入位,电路产生 n 位输出,告诉您最终的总和(上面的论文值得一读!)。然后,您可以使用强制特定输出位的单位子句断言该总和恰好为一。

仅对电路进行编码以强制其某些输出为恒定值似乎是多余的。然而,这更容易推理,并且现代求解器意识到编码可以做到这一点并对其进行优化。它们的处理过程比初始编码过程更复杂。使用求解器的“艺术”在于了解和测试这些替代编码何时比其他编码更有效。

请注意,exactly_k_ofat_least_k_of 以及 at_most_k_of。您已在 Q== 实现中注意到这一点。实现 at_least_1_of 很简单,是一个子句; at_most_1_of 非常常见,通常简称为 AMO。我鼓励您尝试以本文中讨论的其他一些方式实现 <> (甚至可能根据输入大小选择使用哪种方式)以获得感受它。


将我们的注意力转回 k! 可见性约束,我们需要的是一个电路,它告诉我们从某个方向可以看到多少座塔,然后我们可以断言它是一个特定值。

停下来想想如何做到这一点,这并不容易!

与各种 one_of 方法类似,我们可以使用“纯”电路进行计数或使用更简单的电路但规模更差的成对方法。我在这个答案的最底部(‡)附上了纯电路方法的草图。现在我们将使用成对方法。

主要观察是,在不可见的塔中,我们不关心它们的排列。考虑一下:

3 -> 1 5 _ _ _ 9 _ _ _
     A B C D E F G H I

只要 CDE 组包含 234,我们就会从左侧看到 3 座塔,同样如果 GHI 组包含包含 678。但它们在组中出现的顺序对可见的塔没有任何影响。

我们不会计算哪些塔是可见的,而是声明哪些塔是可见的并遵循它们的含义。我们将填写此函数:

def visible_in_span(points: Collection[str], desired: int) -> "cnf":
    """Assert desired visible towers in span. Wlog, visibility is from index 0."""
    points = list(points)
    n = len(points)
    assert desired <= n

    cnf = []

    # ...

    return cnf

假设固定跨度和观察方向:每个位置将有 k 个关联变量,Av1Avk 表示“这是第 k 个可见塔” ”。我们还有 Av ≡ (Av1 ∨ Av2 ∨ ⋯ ∨ Avk) 意思是“A 有一座可见的塔”。

在上面的示例中,Av1Bv2Fv3 均为 true。排放有一些明显的含义。在某个位置,最多其中一个是正确的(你不能同时是第一个和第二个可见的塔)——但不完全是一个,因为拥有一个不可见的塔是完全可以的。另一个是,如果一个位置是第 k 个可见塔,则没有其他位置也是第 k 个可见塔。

到目前为止,我们可以添加以下内容:

is_kth_visible_tower_at = {}
is_kth_visible_tower_vars = collections.defaultdict(list)
is_visible_tower_at = {}
for point in points:
    is_visible_tower_vars = []
    for k in range(1, n + 1):
        # Xvk
        is_kth_visible_tower_var = ext_var()

        is_kth_visible_tower_at[(point, k)] = is_kth_visible_tower_var
        is_kth_visible_tower_vars[k].append(is_kth_visible_tower_var)
        is_visible_tower_vars.append(is_kth_visible_tower_var)

    # Xv
    is_visible_tower_at_var = ext_var()
    # Xv → (Xv1 ∨ Xv2 ∨ ⋯)
    cnf.append(tuple([neg(is_visible_tower_at_var)] + is_visible_tower_vars))
    # (Xv1 ∨ Xv2 ∨ ⋯) → Xv
    for is_visible_tower_var in is_visible_tower_vars:
        cnf.append((neg(is_visible_tower_var), is_visible_tower_at_var))

    is_visible_tower_at[point] = is_visible_tower_at_var

    # At most one visible tower here.
    cnf += Q(is_visible_tower_vars) <= 1

# At most one kth visible tower anywhere.
for k in range(1, n + 1):
    cnf += Q(is_kth_visible_tower_vars[k]) <= 1

接下来,我们需要在可见塔之间进行排序,以便第 k + 1 个可见塔位于第 k 个可见塔之后。这是通过第 k+1 个可见塔迫使至少一个先前位置成为第 k 个可见塔来实现的。例如,Dv3 → (Av2 ∨ Bv2 ∨ Cv2)Cv2 → (Av1 ∨ Bv1)。我们知道 Av1 始终为真,这提供了基本情况。 (如果我们进入需要 B 成为第三个可见塔的情况,则需要 A 成为第二个可见塔,这与 Av1 相矛盾。)

# Towers are ordered.
for index, point in enumerate(points):
    if index == 0:
        cnf += basic_fact(is_kth_visible_tower_at[(point, 1)])
        continue

    for k in range(1, n + 1):
        # Xvk → ⋯
        implication = [neg(is_kth_visible_tower_at[(point, k)])]

        j = k - 1
        if j > 0:
            for index_j, point_j in enumerate(points):
                if index_j == index:
                    break

                # ⋯ ∨ Wxj ∨ ⋯
                implication.append(is_kth_visible_tower_at[(point_j, j)])

        cnf.append(tuple(implication))

到目前为止一切顺利,但我们还没有关联塔的高度到可见度。上面将允许 9 8 7 作为解决方案,调用 9 第一个可见塔,8 第二个,以及 7第三个。为了解决这个问题,我们希望塔的放置能够防止较小的塔也可见。

每个位置将再次收到一组缩写,指示它是否在特定高度以下被遮挡,称为 Ao1Ao2 等。这将为我们提供一个“网格”,让事情变得更简单。首先,较高的塔被遮挡意味着同一位置的下一个最高的塔也被遮挡,因此 Ao3 → Ao2Ao2 → Ao1。第二个是,如果一座塔在一个位置被遮挡,那么它在以后的所有位置也会被遮挡。这是 Ao3 → Bo3Bo3 → Co3 等等。

is_height_obscured_at = {}
is_height_obscured_previous = [None] * n
for point in points:
    is_obscured_previous = None
    for k in range(1, n + 1):
        # Xok
        is_height_obscured_var = ext_var()

        # Wok → Xok
        is_k_obscured_previous = is_height_obscured_previous[k - 1]
        if is_k_obscured_previous is not None:
            cnf.append((neg(is_k_obscured_previous), is_height_obscured_var))

        # Xok → Xo(k-1)
        if is_obscured_previous is not None:
            cnf.append((neg(is_height_obscured_var), is_obscured_previous))

        is_height_obscured_at[(point, k)] = is_height_obscured_var
        is_height_obscured_previous[k - 1] = is_height_obscured_var
        is_obscured_previous = is_height_obscured_var

由此很容易看出,例如 Bo4 意味着其余高度等于或小于 4 的塔都被遮挡了。我们现在可以轻松地将塔的放置与隐蔽性联系起来:A5 → Bo4

# A placed tower obscures smaller later towers.
for index, point in enumerate(points):
    if index + 1 == len(points):
        break

    next_point = points[index + 1]
    for k in range(2, n + 1):
        j = k - 1

        # Xk → Yo(k-1)
        cnf.append((neg(comb(point, k)), is_height_obscured_at[(next_point, j)]))

最后,我们需要将模糊性与可见性联系起来。我们需要最后一组缩写,说明在某个位置可以看到特定的塔高度。冒着容易打错的风险,我们将这个高度 h 称为 Ahv,因此 Ahv ≡ (Ah ∧ Av)。一个具体的例子是C3v == (C3 ∧ Cv):当且仅当在C处有一座可见的塔时,高度为3的塔在C处可见,并且该塔是高度为3的塔。

is_height_visible_at = {}
for point in points:
    for k in range(1, n + 1):
        # Xhv
        height_visible_at_var = ext_var()

        # Xhv ≡ (Xh ∧ Xv)
        cnf.append((neg(height_visible_at_var), comb(point, k)))
        cnf.append((neg(height_visible_at_var), is_visible_tower_at[point]))
        cnf.append(
            (
                neg(comb(point, k)),
                neg(is_visible_tower_at[point]),
                height_visible_at_var,
            )
        )

        is_height_visible_at[(point, k)] = height_visible_at_var

这使我们能够得出对塔放置的最终影响。如果高度为 h 的塔被遮挡,则它是不可见的:Bo4 → ØB4v。这不是等价的,我们不能对待Bo4 ≠ B4v;也许 ØB4v 成立,因为 B4 根本没有放置在那里(但它是可见的!)。

for point in points:
    for k in range(1, n + 1):
        # Xok → ¬Xkv
        cnf.append(
            (
                neg(is_height_obscured_at[(point, k)]),
                neg(is_height_visible_at[(point, k)]),
            )
        )

为了将其与特定谜题的可见性值联系起来,我们只需要禁止太多的可见塔,并确保所需的计数至少可见一个(因此恰好一次):

# At least one of the towers is the desired kth visible.
cnf.append(tuple(is_kth_visible_tower_vars[desired]))

# None of the towers can be visible above the desired kth.
if desired < n:
    for is_kth_visible_tower_var in is_kth_visible_tower_vars[desired + 1]:
        cnf += basic_fact(neg(is_kth_visible_tower_var))

return cnf

我们只需要阻止第一层不需要的第 k 个可见塔。由于第 k + 1 层意味着存在第 k 层可见塔,因此它也被排除。 (等等。)

最后,我们将其连接到 CNF 构建器中:

# Set visible from left
if self.visible_from_left:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_left[index]
        if not target_visible:
            continue

        cnf += visible_in_span(row, target_visible)

# Set visible from right
if self.visible_from_right:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_right[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(row), target_visible)

# Set visible from top
if self.visible_from_top:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_top[index]
        if not target_visible:
            continue

        cnf += visible_in_span(col, target_visible)

# Set visible from bottom
if self.visible_from_bottom:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_bottom[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(col), target_visible)

上面的内容让我更快地得到了 9x9 解决方案:

CNF: 0.028973951935768127s
Solve: 0.07169117406010628s

大约快了 685 倍,并且求解器正在完成更多的整体工作。又快又脏还不错!

有很多方法可以清理这个问题。例如,为了可读性,我们看到的每个地方 cnf.append((neg(a), b)) 都可以是 cnf += implys(a, b) 。我们可以避免分配无意义的大 kth 可见变量,等等。

这还没有经过充分测试;我可能错过了一些含义或规则。希望此时很容易修复。


我想谈的最后一件事是 SAT 的适用性。也许现在很痛苦地清楚了,SAT 求解器并不擅长计数和算术。你必须降低到一个电路,从求解过程中隐藏更高级别的语义。

其他方法可以让你自然地表达算术、区间、集合等。答案集编程 (ASP) 就是一个例子,SMT 求解器是另一个例子。对于小问题,SAT 没问题,但对于困难问题,这些更高层次的方法可以大大简化问题。

他们中的每一个实际上可能在内部决定通过 SAT 解决(特别是 SMT)进行推理,但他们将在对问题的一些更高层次的理解的背景下这样做。


‡ 这是计数塔的纯电路方法。

它是否比pairwise更好取决于所计算的塔的数量;也许常数因子太高而没有用处,或者即使在小尺寸下它也非常有用。老实说,我不知道——我以前编码过巨大的电路并且让它们工作得很好。需要实验才能知道。

我将把 Ah 称为位置 A 中塔的整数高度。也就是说,而不是 A1A2< 的 one-hot 编码/code> 或 ... 或 A9 我们将 Ah0Ah1、... 和 Ahn 作为低位通过 n 位整数的高位(统称为)。对于 9x9 的限制,4 位就足够了。我们还将有 BhCh 等。

您可以使用 A1 ≡ (ØAh3 ∧ ØAh2 ∧ ØAh1 ∧ Ah0)A2 ≡ (ØAh3 ∧ ØAh2 ∧ Ah1 ∧ ØAh0) 连接这两个表示和 A3 ≡ (ØAh3 ∧ ØAh2 ∧ Ah1 ∧ Ah0) 等等。当且仅当设置了 A3 时,我们才有 Ah = 3。 (我们不需要添加一次只能有一个 Ah 值的约束,因为与每个值关联的 one-hot 变量已经做到了这一点。)

有了一个整数在手,它可能更容易了解如何计算可见性。我们可以将每个位置与最大可见塔高度相关联,命名为 AmBm 等;显然,第一座塔总是可见的,并且是最高的,所以Am == Ah。同样,这实际上是一个 n 位值 Am0Amn

当且仅当塔的值大于先验的最高值时,塔才可见。我们将使用 AvBv 等来跟踪可见性。这可以通过数字比较器来完成;使得 Bv ≠ Bh >上午。 (Av 是一个基本情况,并且始终为真。)

这也让我们可以填写其余的最大值。 Bm = Bv ? Bh:Am,等等。条件/if-then-else/ite 是数字多路复用器。对于简单的 2 对 1,这很简单:Bv ? Bh : Am(Bv ∧ Bh) ∨ (ØBv ∧ Am),实际上是 (Bv ∧ Bhi) ∨ (ØBv ∧ Ami)对于每个i ∈ 0..n

然后,我们将有一堆单个输入 AvIv 馈入加法器电路,告诉我们这些输入中有多少是真实的(即有多少座塔)是可见的)。这将是另一个 n 位值;然后我们使用单元子句来断言它正是例如 3,如果特定的谜题需要 3 个可见的塔。

我们为每个方向的每个跨度生成相同的电路。这将是一些多项式大小的规则编码,添加许多扩展变量和许多子句。求解器可以了解到某个塔的放置是不可行的,不是因为我们这么说的,而是因为它意味着一些不可接受的中间可见性。 “应该有 4 个可见,而 2 个已经可见,所以剩下的就是……”。

First, it's good to note the difference between equivalence and equisatisfiability. In general, converting an arbitrary boolean formula (say, something in DNF) to CNF can result in a exponential blow-up in size.

This blow-up is the issue with your from_dnf approach: whenever you handle another product term, each of the literals in that product demands a new copy of the current cnf clause set (to which it will add itself in every clause). If you have n product terms of size k, the growth is O(k^n).

In your case n is actually a function of k!. What's kept as a product term is filtered to those satisfying the view constraint, but overall the runtime of your program is roughly in the region of O(k^f(k!)). Even if f grows logarithmically, this is still O(k^(k lg k)) and not quite ideal!

Because you're asking "is this satisfiable?", you don't need an equivalent formula but merely an equisatisfiable one. This is some new formula that is satisfiable if and only if the original is, but which might not be satisfied by the same assignments.

For example, (a ∨ b) and (a ∨ c) ∧ (¬b) are each obviously satisfiable, so they are equisatisfiable. But setting b true satisfies the first and falsifies the second, so they are not equivalent. Furthermore the first doesn't even have c as a variable, again making it not equivalent to the second.

This relaxation is enough to replace this exponential blow-up with a linear-sized translation instead.


The critical idea is the use of extension variables. These are fresh variables (i.e., not already present in the formula) that allow us to abbreviate expressions, so we don't end up making multiple copies of them in the translation. Since the new variable is not present in the original, we'll no longer have an equivalent formula; but because the variable will be true if and only if the expression is, it will be equisatisfiable.

If we wanted to use x as an abbreviation of y, we'd state x ≡ y. This is the same as x → y and y → x, which is the same as (¬x ∨ y) ∧ (¬y ∨ x), which is already in CNF.

Consider the abbreviation for a product term: x ≡ (a ∧ b). This is x → (a ∧ b) and (a ∧ b) → x, which works out to be three clauses: (¬x ∨ a) ∧ (¬x ∨ b) ∧ (¬a ∨ ¬b ∨ x). In general, abbreviating a product term of k literals with x will produce k binary clauses expressing that x implies each of them, and one (k+1)-clause expressing that all together they imply x. This is linear in k.

To really see why this helps, try converting (a ∧ b ∧ c) ∨ (d ∧ e ∧ f) ∨ (g ∧ h ∧ i) to an equivalent CNF with and without an extension variable for the first product term. Of course, we won't just stop with one term: if we abbreviate each term then the result is precisely a single CNF clause: (x ∨ y ∨ z) where these each abbreviate a single product term. This is a lot smaller!

This approach can be used to turn any circuit into an equisatisfiable formula, linear in size and in CNF. This is called a Tseitin transformation. Your DNF formula is simply a circuit composed of a bunch of arbitrary fan-in AND gates, all feeding into a single arbitrary fan-in OR gate.

Best of all, although this formula is not equivalent due to additional variables, we can recover an assignment for the original formula by simply dropping the extension variables. It is sort of a 'best case' equisatisfiable formula, being a strict superset of the original.


To patch this into your code, I added:

# Uses pseudo-namespacing to avoid collisions.
_EXT_SUFFIX = "___"
_NEXT_EXT_INDEX = 0


def is_ext_var(element) -> bool:
    return element.endswith(_EXT_SUFFIX)


def ext_var() -> str:
    global _NEXT_EXT_INDEX
    ext_index = _NEXT_EXT_INDEX
    _NEXT_EXT_INDEX += 1

    return intern(f"{ext_index}{_EXT_SUFFIX}")

This lets us pull a fresh named variable out of thin air. Since these extension variable names don't have meaningful semantics to your solution display function, I changed:

point_to_value = {
    point: value for point, value in [fact.split() for fact in self.solution]
}

into:

point_to_value = {
    point: value
    for point, value in [
        fact.split() for fact in self.solution if not is_ext_var(fact)
    ]
}

There are certainly better ways to do this, this is just a patch. :)

Reimplementing your from_dnf with the above ideas, we get:

def from_dnf(groups) -> "cnf":
    "Convert from or-of-ands to and-of-ors, equisatisfiably"
    cnf = []

    extension_vars = []
    for group in groups:
        extension_var = ext_var()
        neg_extension_var = neg(extension_var)

        imply_ext_clause = []
        for literal in group:
            imply_ext_clause.append(neg(literal))
            cnf.append((neg_extension_var, literal))

        imply_ext_clause.append(extension_var)
        cnf.append(tuple(imply_ext_clause))

        extension_vars.append(extension_var)

    cnf.append(tuple(extension_vars))
    return cnf

Each group gets an extension variable. Each literal in the group adds its negation into the (k+1)-sized implication clause, and becomes implied by the extension. After the literals are handled, the extension variable finalizes the remaining implication and adds itself to the list of new extension variables. Finally, at least one of these extension variables must be true.

This change alone lets me solve this 5x5 puzzle ~instantly:

self.visible_from_top = [3, 2, 1, 4, 2]
self.visible_from_bottom = [2, 2, 4, 1, 2]
self.visible_from_left = [3, 2, 3, 1, 3]
self.visible_from_right = [2, 2, 1, 3, 2]
self.given_numbers = {}

I added some timing output as well:

@property
def solution(self):
    if self._solution is None:
        start_time = time.perf_counter()

        cnf = self.cnf
        cnf_time = time.perf_counter()
        print(f"CNF: {cnf_time - start_time}s")

        self._solution = solve_one(cnf)
        end_time = time.perf_counter()
        print(f"Solve: {end_time - cnf_time}s")
    return self._solution

The 5x5 puzzle gives me:

CNF: 0.00565183162689209s
Solve: 0.005589433014392853s

However, we still have that pesky k! growth when enumerating viable tower height permutations.

I generated a 9x9 puzzle (the largest the site permits), which corresponds to:

self.visible_from_top = [3, 3, 3, 3, 1, 4, 2, 4, 2]
self.visible_from_bottom = [3, 1, 4, 2, 5, 3, 3, 2, 3]
self.visible_from_left = [3, 3, 1, 2, 4, 5, 2, 3, 2]
self.visible_from_right = [3, 1, 7, 4, 3, 3, 2, 2, 4]
self.given_numbers = {
    "AB": 5,
    "AD": 4,
    "BD": 3,
    "BE": 2,
    "CD": 7,
    "CF": 5,
    "CG": 1,
    "DB": 1,
    "DH": 7,
    "EA": 4,
    "EI": 2,
    "FA": 2,
    "FE": 8,
    "GG": 7,
    "GI": 6,
    "HA": 3,
    "HF": 2,
    "HH": 1,
    "IG": 6,
}

This gives me:

CNF: 28.505195066332817s
Solve: 40.48229135945439s

We should spend more time solving and less time generating, but close to half the time is generating.

In my opinion, using DNF in a CNF-SAT translation is often† a sign of the wrong approach. Solvers are way better at exploring and learning about the solution space than we are — spending factorial amount of time pre-exploring is actually worse than the solver's exponential worse case.

It's understandable to 'fall back' to DNF, because programmers naturally think in terms of "write an algorithm that emits solutions". But the real benefit of solvers kicks in when you encode this in the problem. Let the solver reason about conditions in which solutions become infeasible. To do this, we want to think in terms of circuits. Lucky for us, we also know how to turn a circuit into CNF quickly.

†I said "often"; if your DNF is small and quick to produce (like a single circuit gate), or if encoding it to a circuit is prohibitively complicated, then it can be helpful to pre-compute some of the solution space.


You've actually already done some of this! For example, we will need a circuit that counts how many times a certain number appears in a span (row or column), and an assertion that this number is exactly one. Then for each span and for each number, we'll emit this circuit. That way if a tower of size e.g. 3 appears twice in a row, the counter for that row for 3 will emit '2' and our assertion that it be '1' will not be upheld.

Your one_of constraint is one possible implementation of this. Yours uses the 'obvious' pairwise encoding: for each location in the span, if N is present at that location then it is not present in any other location. This is actually quite a good encoding because it's comprised almost entirely of binary clauses, and SAT solvers love binary clauses (they use significantly less memory and propagate often). But for really large sets of things to count, this O(n^2) scaling can become an issue.

You can imagine an alternative approach where you literally encode an adder circuit: each location is an input bit to the circuit, and the circuit produces n bits of output telling you the final sum (the paper above is a good read!). You then assert this sum is exactly one using unit clauses that force specific output bits.

It may seem redundant to encode a circuit only to force some of its outputs to be a constant value. However, this is much easier to reason about and modern solvers are aware that encodings do this and optimize for it. They perform significantly more sophisticated in-processing than the initial encoding process could reasonably do. The 'art' of using solvers is in knowing and testing when these alternative encodings work better than others.

Note that exactly_k_of is at_least_k_of along with at_most_k_of. You've noted this in your Q class == implementation. Implementing at_least_1_of is trivial, being one clause; at_most_1_of is so common it's often just called AMO. I encourage you to try implementing < and > in some of the other ways discuss in the paper (perhaps even choosing which to use based on input size) to get a feel for it.


Turning our attention back to the k! visibility constraints, what we need is a circuit that tells us how many towers are visible from a certain direction, which we can then assert be a specific value.

Stop and think about how this could be done, it's not easy!

Analogous to the various one_of approaches, we can go with a 'pure' circuit for counting or use a simpler but worse-scaling pairwise approach. I have attached the sketch of the pure circuit approach at the very bottom (‡) of this answer. For now we will use the pairwise method.

The main observation to make is that among non-visible towers, we don't care about their permutations. Consider:

3 -> 1 5 _ _ _ 9 _ _ _
     A B C D E F G H I

We see 3 towers from the left as long as the CDE group contains 2, 3, and 4, and likewise if the GHI group contains 6, 7, and 8. But the order in which they appear in the group has no implication on the visible towers.

Rather than compute which towers are visible, we will declare which towers are visible and follow their implications. We'll be filling in this function:

def visible_in_span(points: Collection[str], desired: int) -> "cnf":
    """Assert desired visible towers in span. Wlog, visibility is from index 0."""
    points = list(points)
    n = len(points)
    assert desired <= n

    cnf = []

    # ...

    return cnf

Assume a fixed span and viewing direction: each location will have k associated variables, Av1 through Avk stating "this is the kth visible tower". We will also have Av ≡ (Av1 ∨ Av2 ∨ ⋯ ∨ Avk) meaning "A has a visible tower".

In the above example, Av1, Bv2, and Fv3 are all true. There are some obvious implications to emit. At a location, at most one of these is true (you can't be both the first and second visible tower) — but not exactly one, since it's perfectly fine to have a non-visible tower. Another is that if a location is the kth visible tower, then no other location is also the kth visible tower.

We can add this so far:

is_kth_visible_tower_at = {}
is_kth_visible_tower_vars = collections.defaultdict(list)
is_visible_tower_at = {}
for point in points:
    is_visible_tower_vars = []
    for k in range(1, n + 1):
        # Xvk
        is_kth_visible_tower_var = ext_var()

        is_kth_visible_tower_at[(point, k)] = is_kth_visible_tower_var
        is_kth_visible_tower_vars[k].append(is_kth_visible_tower_var)
        is_visible_tower_vars.append(is_kth_visible_tower_var)

    # Xv
    is_visible_tower_at_var = ext_var()
    # Xv → (Xv1 ∨ Xv2 ∨ ⋯)
    cnf.append(tuple([neg(is_visible_tower_at_var)] + is_visible_tower_vars))
    # (Xv1 ∨ Xv2 ∨ ⋯) → Xv
    for is_visible_tower_var in is_visible_tower_vars:
        cnf.append((neg(is_visible_tower_var), is_visible_tower_at_var))

    is_visible_tower_at[point] = is_visible_tower_at_var

    # At most one visible tower here.
    cnf += Q(is_visible_tower_vars) <= 1

# At most one kth visible tower anywhere.
for k in range(1, n + 1):
    cnf += Q(is_kth_visible_tower_vars[k]) <= 1

Next we need ordering among visible towers, so that the kth + 1 visible tower comes after the kth visible tower. This is accomplished by the kth + 1 visible tower forcing at least one of the prior locations to be the kth visible tower. E.g., Dv3 → (Av2 ∨ Bv2 ∨ Cv2) and Cv2 → (Av1 ∨ Bv1). We know Av1 is always true which provides the base case. (If we enter a situation like needing B to be the third visible tower, that will require A be the second visible tower which contradicts Av1.)

# Towers are ordered.
for index, point in enumerate(points):
    if index == 0:
        cnf += basic_fact(is_kth_visible_tower_at[(point, 1)])
        continue

    for k in range(1, n + 1):
        # Xvk → ⋯
        implication = [neg(is_kth_visible_tower_at[(point, k)])]

        j = k - 1
        if j > 0:
            for index_j, point_j in enumerate(points):
                if index_j == index:
                    break

                # ⋯ ∨ Wxj ∨ ⋯
                implication.append(is_kth_visible_tower_at[(point_j, j)])

        cnf.append(tuple(implication))

So far so good, but we haven't related tower height to visibility. The above would allow 9 8 7 as a solution, calling 9 the first visible tower, 8 the second, and 7 the third. To solve this we want a tower placement to prohibit a smaller tower from also being visible.

Each location will again receive a set of abbreviations indicating if it obscured below a certain height, called Ao1, Ao2, and so on. This will give us a 'grid' of implications that keep things simpler. The first is that a higher tower being obscured implies the next highest tower at the same location is also obscured, so that Ao3 → Ao2 and Ao2 → Ao1. The second is that if a tower is obscured at one location, it is also obscured at all later locations. This is Ao3 → Bo3 and Bo3 → Co3 and so on.

is_height_obscured_at = {}
is_height_obscured_previous = [None] * n
for point in points:
    is_obscured_previous = None
    for k in range(1, n + 1):
        # Xok
        is_height_obscured_var = ext_var()

        # Wok → Xok
        is_k_obscured_previous = is_height_obscured_previous[k - 1]
        if is_k_obscured_previous is not None:
            cnf.append((neg(is_k_obscured_previous), is_height_obscured_var))

        # Xok → Xo(k-1)
        if is_obscured_previous is not None:
            cnf.append((neg(is_height_obscured_var), is_obscured_previous))

        is_height_obscured_at[(point, k)] = is_height_obscured_var
        is_height_obscured_previous[k - 1] = is_height_obscured_var
        is_obscured_previous = is_height_obscured_var

From this it's easy to see that stating e.g. Bo4 implies the remaining towers equal or less than 4 in height are all obscured. We can now easily relate tower placement to obscurity: A5 → Bo4.

# A placed tower obscures smaller later towers.
for index, point in enumerate(points):
    if index + 1 == len(points):
        break

    next_point = points[index + 1]
    for k in range(2, n + 1):
        j = k - 1

        # Xk → Yo(k-1)
        cnf.append((neg(comb(point, k)), is_height_obscured_at[(next_point, j)]))

Last, we need to relate obscurity to visibility. We'll need one final set of abbreviations, stating that a specific tower height is visible at a location. At the risk of making typos easy, we'll call this Ahv for some height h, so that Ahv ≡ (Ah ∧ Av). A concrete example would be C3v ≡ (C3 ∧ Cv): a tower of height 3 is visible at C if and only if there is a tower visible at C, and that tower is the height 3 tower.

is_height_visible_at = {}
for point in points:
    for k in range(1, n + 1):
        # Xhv
        height_visible_at_var = ext_var()

        # Xhv ≡ (Xh ∧ Xv)
        cnf.append((neg(height_visible_at_var), comb(point, k)))
        cnf.append((neg(height_visible_at_var), is_visible_tower_at[point]))
        cnf.append(
            (
                neg(comb(point, k)),
                neg(is_visible_tower_at[point]),
                height_visible_at_var,
            )
        )

        is_height_visible_at[(point, k)] = height_visible_at_var

This allows us to emit the final implications on tower placement. If a tower of height h is obscured, it is not visible: Bo4 → ¬B4v. This is not an equivalence, and we cannot treat Bo4 ≡ ¬B4v; maybe ¬B4v holds because B4 simply isn't placed there (but would be visible it it were!).

for point in points:
    for k in range(1, n + 1):
        # Xok → ¬Xkv
        cnf.append(
            (
                neg(is_height_obscured_at[(point, k)]),
                neg(is_height_visible_at[(point, k)]),
            )
        )

To relate this to the puzzle-specific visibility value, we just need to prohibit too many visible towers and ensure the desired count is visible at least one (and therefore exactly once):

# At least one of the towers is the desired kth visible.
cnf.append(tuple(is_kth_visible_tower_vars[desired]))

# None of the towers can be visible above the desired kth.
if desired < n:
    for is_kth_visible_tower_var in is_kth_visible_tower_vars[desired + 1]:
        cnf += basic_fact(neg(is_kth_visible_tower_var))

return cnf

We only need to block the first level of undesirable kth visible towers. Since the kth + 1 level will imply the existence of a kth level visible tower, it too is ruled out. (And so on.)

Finally, we hook this into the CNF builder:

# Set visible from left
if self.visible_from_left:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_left[index]
        if not target_visible:
            continue

        cnf += visible_in_span(row, target_visible)

# Set visible from right
if self.visible_from_right:
    for index, row in enumerate(self.rows):
        target_visible = self.visible_from_right[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(row), target_visible)

# Set visible from top
if self.visible_from_top:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_top[index]
        if not target_visible:
            continue

        cnf += visible_in_span(col, target_visible)

# Set visible from bottom
if self.visible_from_bottom:
    for index, col in enumerate(self.cols):
        target_visible = self.visible_from_bottom[index]
        if not target_visible:
            continue

        cnf += visible_in_span(reversed(col), target_visible)

The above gives me the 9x9 solution much more quickly:

CNF: 0.028973951935768127s
Solve: 0.07169117406010628s

About 685x faster, and the solver is doing more of the overall work. Not bad for quick and dirty!

There are many ways to clean this up. E.g., every place we see cnf.append((neg(a), b)) could be cnf += implies(a, b) instead for readability. We could avoid allocation of pointlessly large kth visible variables, and so on.

This is not well-tested; I may have missed some implications or a rule. Hopefully it's easy to fix at this point.


The last thing I want to touch on is the applicability of SAT. Perhaps painfully clear now, SAT solvers are not exactly great at counting and arithmetic. You have to lower to a circuit, hiding the higher-level semantics from the solving process.

Other approaches will let you express arithmetic, intervals, sets, and so on naturally. Answer set programming (ASP) is one example of this, SMT solvers are another. For small problems SAT is fine, but for difficult problems these higher-level approaches can greatly simplify the problem.

Each of those may actually internally decide to reason via SAT-solving (SMT in particular), but they will be doing so in the context of some higher-level understanding of the problem.


‡ This is the pure circuit approach to counting towers.

Whether or not it is better than pairwise will depend on the number of towers being counted; maybe the constant factors are so high it's never useful, or maybe it's quite useful even at low sizes. I honestly have no idea — I've encoded huge circuits before and had them work great. It requires experimentation to know.

I'm going to call Ah the integer height of the tower in location A. That is, rather than a one-hot encoding of either A1 or A2 or … or A9 we'll have Ah0, Ah1, …, and Ahn as the low through high bits of an n-bit integer (collectively Ah). For a limit of 9x9, 4 bits suffice. We'll also have Bh, Ch, and so on.

You can join the two representations using A1 ≡ (¬Ah3 ∧ ¬Ah2 ∧ ¬Ah1 ∧ Ah0) and A2 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ ¬Ah0) and A3 ≡ (¬Ah3 ∧ ¬Ah2 ∧ Ah1 ∧ Ah0) and so on. We have Ah = 3 if and only if A3 is set. (We don't need to add constraints that only one value of Ah is possible at a time, since the one-hot variables associated to each do this already.)

With an integer in hand, it might be easier to see how to compute visibility. We can associate each location with a maximum seen tower height, named Am, Bm, and so on; obviously the first tower is always visible and the highest seen, so Am ≡ Ah. Again, this is actually an n-bit value Am0 through Amn.

A tower is visible if and only if it's value is larger than the prior's highest seen. We'll track visibility with Av, Bv, and so on. This can be done with a digital comparator; so that Bv ≡ Bh > Am. (Av is a base case, and is simply always true.)

This lets us fill in the rest of the max values as well. Bm ≡ Bv ? Bh : Am, and so on. A conditional/if-then-else/ite is a digital multiplexer. For a simple 2-to-1 this is straightforward: Bv ? Bh : Am is (Bv ∧ Bh) ∨ (¬Bv ∧ Am), which is really (Bv ∧ Bhi) ∨ (¬Bv ∧ Ami) for each i ∈ 0..n.

Then, we'll have a bunch of single inputs Av through Iv that feed into an adder circuit, telling us how many of these inputs are true (i.e., how many towers are visible). This will be yet another n-bit value; then we use unit clauses to assert that it is exactly e.g. 3, if the particular puzzle demands 3 visible towers.

We generate this same circuit for every span in every direction. This will be some polynomial-sized encoding of the rules, adding many extension variables and many clauses. A solver can learn a certain tower placement isn't viable not because we said so, but because it implies some unacceptable intermediate visibility. "There should be 4 visible, and 2 are already visible, so that leaves me with...".

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文