#!/usr/bin/python3

import sys
import copy

# Input file: the first command-line argument if given, otherwise the default
# puzzle input next to the script.
filename = sys.argv[1] if len(sys.argv) > 1 else "37_26_seats.txt"

# Read the grid as a list of rows, each a list of single-character cells
# (trailing whitespace/newline stripped). The `with` block closes the file
# as soon as it has been read.
with open(filename, "r") as seat_file:
    base_states = [list(line.rstrip()) for line in seat_file]

# Working copy for part 1; base_states is kept untouched so part 2 can
# restart from the original layout.
states = copy.deepcopy(base_states)

def get_adjacent_seats(states, y, x, seat_map):
    """Count occupied seats ('#') among the up-to-eight cells touching (y, x).

    Neighbour positions outside the grid are ignored. This includes positions
    past the end of a row that is shorter than row *y*; they are treated as
    floor rather than raising IndexError.

    Args:
        states: grid as a list of rows, each indexable by column.
        y: row of the cell whose neighbours are counted.
        x: column of the cell whose neighbours are counted.
        seat_map: unused here; accepted so this function has the same
            signature as get_nearest_seats and either can be passed to
            apply_rules as its ``check`` callable.

    Returns:
        The number of neighbouring cells equal to '#'.
    """
    occupied = 0
    for check_y in (y - 1, y, y + 1):
        # Explicit lower bound: a negative index would silently wrap around.
        if not 0 <= check_y < len(states):
            continue
        row = states[check_y]
        for check_x in (x - 1, x, x + 1):
            if check_y == y and check_x == x:
                continue  # a cell is not its own neighbour
            # Bound by *this* row's length, not row y's.
            if 0 <= check_x < len(row) and row[check_x] == "#":
                occupied += 1

    return occupied

def get_nearest_seats(states, y, x, seat_map):
    """Count how many of the seats listed for (y, x) in *seat_map* are occupied.

    ``seat_map[y][x]`` is an iterable of ``[row, column]`` positions, such as
    the nearest visible seats produced by calc_nearest_seats. Each position
    whose cell in *states* is '#' counts once.
    """
    visible_positions = seat_map[y][x]
    return sum(1 for pos in visible_positions if states[pos[0]][pos[1]] == "#")

def calc_nearest_seats(states):
    """Map every seat to the nearest seat visible in each of the eight directions.

    Starting from each seat (any cell other than floor '.'), walk outward one
    cell at a time along each of the eight compass directions, skipping floor.
    The first seat met in each direction is recorded. Positions past the end
    of a shorter row are treated as floor, so the walk continues beyond them.

    This replaces an all-pairs scan (every seat compared with every other
    cell). The ray walk costs O(seats * max(height, width)) instead of
    O((height * width) ** 2).

    Args:
        states: grid as a list of rows of single-character cells.

    Returns:
        A dict ``{y: {x: [[ny, nx], ...]}}`` with an entry for every seat.
        Rows containing no seats are absent. Each inner list holds at most
        eight ``[row, column]`` pairs, one per direction in which a seat is
        visible. The order of the pairs carries no meaning.
    """
    directions = (
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),           (0, 1),
        (1, -1),  (1, 0),  (1, 1),
    )
    height = len(states)
    # Bound the walks by the widest row so a ray can pass a short row.
    width = max((len(row) for row in states), default=0)

    seats = {}
    for y, row in enumerate(states):
        for x, cell in enumerate(row):
            if cell == '.':
                continue  # floor is never a seat, so it gets no entry
            visible = []
            for step_y, step_x in directions:
                check_y, check_x = y + step_y, x + step_x
                while 0 <= check_y < height and 0 <= check_x < width:
                    other_row = states[check_y]
                    if check_x < len(other_row) and other_row[check_x] != '.':
                        visible.append([check_y, check_x])
                        break  # only the nearest seat in this direction
                    check_y += step_y
                    check_x += step_x
            seats.setdefault(y, {})[x] = visible

    return seats


def apply_rules(states, check=get_adjacent_seats, max_people=4, seat_map=None):
    """Apply one round of the seating rules and return the resulting grid.

    All cells update simultaneously. Every decision reads the unmodified
    *states*, and changes are written to a fresh copy, so *states* itself is
    not mutated.

    Rules, for each non-floor cell:
      * an empty seat ('L') with no occupied neighbours becomes occupied ('#');
      * an occupied seat ('#') with at least *max_people* occupied neighbours
        becomes empty ('L');
      * anything else is left unchanged.

    Args:
        states: grid as a list of rows of single-character cells.
        check: callable ``check(states, y, x, seat_map)`` returning the number
            of occupied neighbours of (y, x).
        max_people: occupied-neighbour count at which an occupied seat empties.
        seat_map: extra data passed through unchanged to *check*.

    Returns:
        A new grid (list of lists) with the same shape as *states*.
    """
    # Cells are single-character strings (immutable), so copying each row's
    # list is enough to isolate the new grid from the old one.
    new_states = [list(row) for row in states]

    for y, row in enumerate(states):
        # Walk each row by its own length instead of row 0's, so rows of
        # differing width neither raise IndexError nor get truncated.
        for x, cell in enumerate(row):
            if cell == ".":
                continue
            adj_seats = check(states, y, x, seat_map)
            if cell == "L" and adj_seats == 0:
                new_states[y][x] = "#"
            elif cell == "#" and adj_seats >= max_people:
                new_states[y][x] = "L"

    return new_states

def get_seats(states):
    """Flatten the grid into a single string, reading rows top to bottom.

    The string is a cheap snapshot for detecting when a round changes nothing
    and for counting '#' characters.
    """
    return "".join(cell for row in states for cell in row)

def print_layout(states, indent_string=""):
    """Print the grid one row per line, then a blank line.

    Args:
        states: grid as a list of rows of single-character cells.
        indent_string: text printed verbatim before each row.
    """
    for line in states:
        # Concatenate rather than pass two arguments to print(): its default
        # sep=" " would add an extra space after the indent (and a stray
        # leading space when indent_string is empty).
        print(indent_string + "".join(line))

    print()

def _run_until_stable(states, check=get_adjacent_seats, max_people=4,
                      seat_map=None):
    """Apply the rules repeatedly until a round changes nothing.

    Returns a ``(states, seats)`` pair: the stable grid and its flattened
    string form from get_seats.
    """
    seats = get_seats(states)
    while True:
        states = apply_rules(states, check, max_people, seat_map)
        previous_seats, seats = seats, get_seats(states)
        if seats == previous_seats:
            return states, seats


def _report(part, states, seats):
    """Print the final grid and the occupied-seat count for one puzzle part."""
    print("Part {}:".format(part))
    print("  Final state:")
    print_layout(states, "  ")
    print("  Occupied: {}".format(seats.count('#')))


# Part 1: seats look at their eight immediate neighbours; an occupied seat
# empties once four or more of them are occupied.
states, seats = _run_until_stable(states)
_report(1, states, seats)

# Part 2: restart from the original layout. Seats look at the first seat
# visible in each direction (precomputed once, since seats never move); an
# occupied seat empties once five or more of those are occupied.
states = copy.deepcopy(base_states)
seat_map = calc_nearest_seats(states)
states, seats = _run_until_stable(states, get_nearest_seats, 5, seat_map)
_report(2, states, seats)
