import gym

from joatmon.ai.models import DQNModel
from joatmon.ai.processor import RLProcessor
from joatmon.ai.trainer import DQNTrainer
from joatmon.callback import (
    CallbackList,
    Loader,
    Renderer,
    TrainLogger,
    ValidationLogger
)
from joatmon.game.sokoban import (
    SokobanEnv
)
from joatmon.ai.memory import RingMemory
from joatmon.ai.policy import (
    EpsilonGreedyPolicy as EGreedy,
    GreedyQPolicy as GreedyQ
)


def create_env():
    try:
        environment = gym.make(
            'Sokoban-Medium-v0', **{
                'xmls': 'game/assets/sokoban/xmls/',
                'sprites': 'game/assets/sokoban/sprites/'
            }
        )
    except Exception as ex:
        print(str(ex))
        environment = SokobanEnv(
            **{
                'xml': 'medium.xml',
                'xmls': 'game/assets/sokoban/xmls/',
                'sprites': 'game/assets/sokoban/sprites/'
            }
        )
    return environment


def run():
    memory = RingMemory()
    processor = RLProcessor()

    model = DQNModel(in_features=3, out_features=4)

    experiment = '.'
    case = '.'
    run_name = 'test'
    run_path = 'saves/01.dqn/{}/{}/{}/'.format(experiment, case, run_name)

    environment = create_env()
    agent = DQNTrainer(
        environment=environment,
        memory=memory,
        processor=processor,
        model=model,
        callbacks=CallbackList(
            [
                TrainLogger(run_path=run_path, interval=10),
                Loader(model=model, run_path=run_path, interval=10),
                Renderer(environment=environment)
            ]
        ),
        train_policy=EGreedy(min_value=0.1),
        test_policy=GreedyQ()
    )
    agent.train(max_episode=2400, warmup=1200000, max_action=200)

    environment = create_env()
    agent = DQNTrainer(
        environment=environment,
        memory=memory,
        processor=processor,
        model=model,
        callbacks=CallbackList(
            [
                ValidationLogger(run_path=run_path, interval=1),
                Renderer(environment=environment),
            ]
        ),
        train_policy=EGreedy(min_value=0.1),
        test_policy=GreedyQ()
    )
    agent.evaluate(max_action=200)


if __name__ == '__main__':
    run()