Enable JavaScript in your browser and then refresh this page for a much better experience.
DFS solution in Clear category for Magic Square by pokosasa
from copy import deepcopy
from itertools import combinations, permutations
def checkio(data):
    """Complete a partially filled normal magic square.

    ``data`` is an n x n grid given as a sequence of rows, where 0 marks an
    empty cell.  The answer must be a normal magic square: it contains every
    number 1..n**2 exactly once, and every row, every column and both main
    diagonals sum to the same magic constant.

    The solver is an iterative depth-first search over partial grids.  At
    each step it picks the line (row, column or diagonal) with the fewest
    empty cells (the first such line wins ties).  It then pushes one child
    state for every permutation of every set of unused numbers that makes
    that line reach the magic constant.  If that line cannot be completed,
    the state is discarded at once, since no full solution exists below it.

    Returns:
        The completed square as a new list of lists (the input is never
        modified), or None if the template cannot be completed.
    """
    n = len(data)
    if n == 0:
        return []
    size = n * n
    # Magic constant: 1 + 2 + ... + n**2, shared equally among the n rows.
    S = size * (size + 1) // 2 // n
    placed = {cell for row in data for cell in row} - {0}
    unused = set(range(1, size + 1)) - placed

    # Coordinates of every line; they do not depend on the grid contents,
    # so they are built once.  Order: rows, columns, diagonal, anti-diagonal.
    line_coords = (
        [[(i, j) for j in range(n)] for i in range(n)]
        + [[(i, j) for i in range(n)] for j in range(n)]
        + [[(i, i) for i in range(n)]]
        + [[(i, n - 1 - i) for i in range(n)]]
    )

    def is_consistent(grid):
        """Return False if some line already rules out a valid completion.

        A complete line must sum to S exactly.  A line with empty cells must
        still be below S, because every empty cell will receive at least 1.
        """
        for coords in line_coords:
            values = [grid[i][j] for i, j in coords]
            total = sum(values)
            if 0 in values:
                if total >= S:
                    return False
            elif total != S:
                return False
        return True

    stack = [([list(row) for row in data], unused)]
    while stack:
        grid, remaining = stack.pop()
        if not is_consistent(grid):
            continue
        if not remaining:
            # Every number is placed and every line sums to S: solved.
            return grid

        # Find the most constrained line (fewest empty cells, first wins ties).
        best = None
        for coords in line_coords:
            empty = [(i, j) for i, j in coords if grid[i][j] == 0]
            if empty and (best is None or len(empty) < len(best[1])):
                best = (coords, empty)
        if best is None:
            # Numbers are left over but no cell is empty: the template held
            # duplicate or out-of-range values and cannot be completed.
            continue

        coords, empty = best
        target = S - sum(grid[i][j] for i, j in coords)
        for combo in combinations(remaining, len(empty)):
            if sum(combo) != target:
                continue
            left = remaining - set(combo)
            for perm in permutations(combo):
                child = [row[:] for row in grid]
                for (i, j), value in zip(empty, perm):
                    child[i][j] = value
                stack.append((child, left))
    # Search space exhausted without a valid completion.
    return None
if __name__ == '__main__':
    # This part is used only for self-testing.
    def check_solution(func, in_square):
        """Run ``func`` on ``in_square`` and validate its answer.

        The answer must meet four conditions:
        * it has the template's n x n size;
        * it is a magic square (all rows, columns and both diagonals share
          one sum);
        * it is a normal magic square (it contains exactly 1..n**2);
        * it keeps every non-zero cell of the template.

        On the first failed check the matching message is printed and False
        is returned; otherwise True is returned.
        """
        SIZE_ERROR = "Wrong size of the answer."
        MS_ERROR = "It's not a magic square."
        NORMAL_MS_ERROR = "It's not a normal magic square."
        NOT_BASED_ERROR = "Hm, this square is not based on given template."
        result = func(in_square)
        # Check sizes against the template; comparing the answer's length
        # with itself (as before) could never fail.
        N = len(in_square)
        if result is None or len(result) != N or any(len(row) != N for row in result):
            print(SIZE_ERROR)
            return False
        # Check it is a magic square: every line shares the first row's sum.
        line_sum = sum(result[0]) if N else 0
        diagonals = (
            [result[i][i] for i in range(N)],
            [result[i][N - i - 1] for i in range(N)],
        )
        if any(sum(line) != line_sum
               for line in (*result, *zip(*result), *diagonals)):
            print(MS_ERROR)
            return False
        # Check it is a normal magic square (uses exactly 1..N**2).
        user_set = {cell for row in result for cell in row}
        if user_set != set(range(1, N ** 2 + 1)):
            print(NORMAL_MS_ERROR)
            return False
        # Check every given (non-zero) cell of the template is preserved.
        for i in range(N):
            for j in range(N):
                if in_square[i][j] and in_square[i][j] != result[i][j]:
                    print(NOT_BASED_ERROR)
                    return False
        return True

    assert check_solution(checkio,
                          [[2, 7, 6],
                           [9, 5, 1],
                           [4, 3, 0]]), "1st example"

    assert check_solution(checkio,
                          [[0, 0, 0],
                           [0, 5, 0],
                           [0, 0, 0]]), "2nd example"

    assert check_solution(checkio,
                          [[1, 15, 14, 4],
                           [12, 0, 0, 9],
                           [8, 0, 0, 5],
                           [13, 3, 2, 16]]), "3rd example"
June 21, 2020