# main.py (5.38 KiB) -- web-viewer header text ("Skip to content", "Snippets Groups Projects") left over from page extraction.
import copy
import json
from pickle import FALSE

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import subprocess

# Short task labels mapped to task JSON file names (presumably ARC task ids).
tasks = dict(
    lines="0b148d64.json",
    grids="90f3ed37.json",
    pour="d4f3cd78.json",
    cross="e21d9049.json",
    stripes="f8c80d96.json",
)


def run_prolog_program(program, curr_dir="",
                       sicstus_path='/usr/local/sicstus4.8.0/bin/sicstus',
                       timeout=30):
    """Run a goal in SICStus Prolog and return the text it produced.

    Args:
        program: Goal string handed to SICStus through ``--goal``.
        curr_dir: Working directory for the Prolog process. An empty string
            (the default) or ``None`` means "inherit the caller's working
            directory".
        sicstus_path: Path of the SICStus executable to launch.
        timeout: Seconds to wait for the process before giving up.

    Returns:
        The captured stdout when the process exits with status 0, the
        captured stderr when it exits with a non-zero status, or ``None``
        when the process does not finish within ``timeout`` seconds.
    """
    command = [sicstus_path, '--noinfo', '--goal', program]

    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            # An empty string is not a usable directory: subprocess would try
            # to chdir into "" and raise FileNotFoundError. None inherits the
            # parent's working directory instead.
            cwd=curr_dir or None,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print("SICStus Prolog timed out.")
        return None

    if result.returncode != 0:
        print("SICStus Prolog reported an error:")
        return result.stderr
    return result.stdout

def hex_to_rgb(hex_color):
    """Convert a hex colour string such as ``"#34A2FE"`` to an (r, g, b) tuple.

    Any leading ``#`` characters are dropped, then the first three pairs of
    hex digits are parsed as integers.
    """
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)


# # Example usage
# hex_color = "#34A2FE"
# rgb_color = hex_to_rgb(hex_color)
# print("RGB Color:", rgb_color)

def plot_grid(rgb_grid):
    """Display an image array with grey lines drawn between its cells.

    The first two axes of ``rgb_grid`` are read as (rows, columns). The
    figure is shown with ``plt.show()``; nothing is returned.
    """
    n_rows, n_cols = rgb_grid.shape[:2]

    _, axes = plt.subplots()

    # One unit per cell, so gridlines at unit spacing fall on cell borders.
    axes.imshow(rgb_grid, extent=(0, n_cols, 0, n_rows), interpolation='none')

    # Gridlines are attached to the minor ticks, placed one unit apart
    # starting at 0.001.
    axes.set_xticks(np.arange(0.001, n_cols, 1), minor=True)
    axes.set_yticks(np.arange(0.001, n_rows, 1), minor=True)
    axes.grid(which='minor', color='grey', linestyle='-', linewidth=1)

    # No tick marks on either axis, and no labels on the major ticks.
    no_marks = dict(bottom=False, left=False)
    axes.tick_params(which='major', labelbottom=False, labelleft=False, **no_marks)
    axes.tick_params(which='minor', **no_marks)

    # Let the axes fill almost the whole figure and keep the cells square.
    plt.subplots_adjust(left=0.005, right=0.995, top=0.995, bottom=0.005)
    axes.set_aspect('equal')

    plt.show()


def rgb_lookup():
    """Read ``colours.csv`` and return its ``colour`` column as RGB triples.

    Each entry of the column is parsed as three consecutive two-digit hex
    values (no leading ``#`` is stripped). The result is an array with one
    row per CSV row, in file order.
    """
    hex_codes = pd.read_csv("colours.csv")['colour']
    triples = [
        tuple(int(code[start:start + 2], 16) for start in (0, 2, 4))
        for code in hex_codes
    ]
    return np.array(triples)


def colour_lookup():
    """Read ``colours.csv`` and return the values of its ``name`` column.

    The values come back in file row order; ``grid2FOL`` indexes the result
    by cell value, so the row position acts as the colour index there.
    """
    colour_table = pd.read_csv("colours.csv")
    name_column = colour_table['name']
    return name_column.values


def nd_sort(arr):
    """Order the (i, j) slices of a 3-D array by their index pair.

    Every ``arr[i, j, :]`` vector is tagged with its ``(i, j)`` position, the
    tagged rows are sorted by ``i`` and then ``j``, and the values are
    reshaped back to ``arr.shape``. Because the tags are generated already in
    that order, the values come back in their original positions; the dtype
    is whatever results from stacking the values next to integer indices.
    """
    n_rows, n_cols, depth = arr.shape
    values = arr.reshape(-1, depth)

    # Row-major (i, j) tags for each flattened slice.
    row_idx = np.repeat(np.arange(n_rows), n_cols)
    col_idx = np.tile(np.arange(n_cols), n_rows)

    tagged = np.column_stack((row_idx, col_idx, values))

    # lexsort treats the LAST key as primary: sort by i, break ties on j.
    order = np.lexsort((tagged[:, 1], tagged[:, 0]))

    # Drop the two tag columns and restore the 3-D shape.
    return tagged[order][:, 2:].reshape(n_rows, n_cols, depth)


def FOL2prolog(preds):
    """Join a two-level collection of predicate strings into one text block.

    The strings inside each inner collection are joined with newlines, and
    the resulting chunks are joined with newlines again.
    """
    return '\n'.join(map('\n'.join, preds))


def prolog2FOL_array(prolog):
    """Split text on newlines and return the non-empty lines as a 1-D array."""
    lines = np.array(prolog.split('\n'))
    non_empty = lines != ''
    return lines[non_empty]


def FOL2grid(preds):
    """Turn ``output_colour(i,j,name).`` predicate strings into a grid.

    ``preds`` is an array of strings; it is flattened, the
    ``output_colour(`` prefix and ``).`` suffix are removed, and each string
    is split on commas. The first two fields are the integer cell position
    and the last field is a colour name, translated by ``colour_names2idx``.

    Returns an array sized to the largest position seen. Cells that no
    predicate mentions keep the initial fill value of 0.
    """
    flat = preds.reshape(-1)

    # Strip the functor and the closing ")." so only "i,j,name" remains.
    without_prefix = np.char.replace(flat, r"output_colour(", "")
    bare = np.char.replace(without_prefix, r").", "")

    fields = np.array(np.char.split(bare, ",").tolist())
    coords = fields[..., :2].astype(int)
    colours = colour_names2idx(fields[..., -1])

    grid = np.zeros(coords.max(0) + 1)
    for coord, colour in zip(coords, colours):
        grid[tuple(coord)] = colour

    return grid


def colour_names2idx(colour_names):
    """Translate an array of colour names into their integer codes.

    The mapping is read from ``colours.csv`` (columns ``name`` and ``int``).

    Raises:
        IndexError: if any entry of ``colour_names`` is not present in the
            ``name`` column; the message lists the offending names.
    """
    table = pd.read_csv("colours.csv")
    name_to_idx = dict(zip(table.name, table.int))
    # Element-wise dictionary lookup over an array of any shape.
    lookup = np.vectorize(name_to_idx.get)

    recognised = np.isin(colour_names, table.name)
    unknown = np.unique(colour_names[~recognised])
    if len(unknown) > 0:
        raise IndexError(f"Incorrect colour names: {'|'.join(unknown)}\n"
                         f"Must be one of: {','.join(table.name)}")

    return lookup(colour_names)


def load_jsons():
    """Load the default task via ``load_task()`` and return it.

    Returns:
        Whatever ``load_task()`` returns for its default path.
    """
    # The loaded task used to be bound to a local and dropped, so the call
    # had no observable result; returning it is harmless for callers that
    # ignore the value.
    # NOTE(review): the plural name and the module-level `tasks` table suggest
    # this was meant to load several tasks -- confirm the intended behaviour.
    return load_task()


def load_task(json_file='data/training/0a938d79.json'):
    """Parse a task JSON file and return the decoded object.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for that file.
    """
    # JSON text is UTF-8; without an explicit encoding open() falls back to
    # the locale's default, which differs between platforms.
    with open(json_file, encoding="utf-8") as handle:
        return json.load(handle)


def grid2FOL(input_grid, prefix):
    """Describe every cell of a grid as a ``<prefix>_colour(i,j,name).`` string.

    Cell values are used as positions into the array returned by
    ``colour_lookup()`` to obtain the colour name. The result is an array of
    strings with one entry per cell, laid out like the grid.
    """
    cells = np.array(input_grid)
    names = colour_lookup()[cells]

    rows = []
    for r in range(cells.shape[0]):
        rows.append(
            [f"{prefix}_colour({r},{c},{names[r, c]})." for c in range(cells.shape[1])]
        )
    return np.array(rows)


def array_and_plot_grid(input_grid):
    """Convert a grid to RGB with ``grid2rgb``, plot it, and return the RGB array."""
    colours = grid2rgb(input_grid)
    plot_grid(colours)
    return colours


def grid2rgb(input_grid):
    """Map each cell value of a grid to an RGB triple.

    The grid is cast to 64-bit integers and used to index the table returned
    by ``rgb_lookup()``, so the result has one extra trailing axis holding
    the colour components.
    """
    indices = np.array(input_grid).astype(np.int64)
    palette = rgb_lookup()
    return palette[indices]


# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    load_jsons()