import gym

from joatmon.ai.models.reinforcement.hybrid.td3 import TD3Model
from joatmon.ai.processor import RLProcessor
from joatmon.ai.trainer import TD3Trainer
from joatmon.callback import (
    CallbackList,
    Loader,
    Renderer,
    TrainLogger,
    ValidationLogger
)
from joatmon.game import SokobanEnv
from joatmon.ai.memory import RingMemory
from joatmon.ai.random import OrnsteinUhlenbeck as RandomProcess


def create_env():
    """Build the Sokoban environment used by this script.

    The environment is first requested from ``gym`` under the id
    ``'Sokoban-Medium-v0'``. If that call raises any ``Exception``, the
    error text is printed and a ``SokobanEnv`` is constructed directly
    with the ``'medium.xml'`` level instead.

    Returns:
        The object returned by ``gym.make`` or, on failure, the
        ``SokobanEnv`` instance.
    """
    # Asset locations are identical for both construction paths.
    asset_dirs = {
        'xmls': 'game/assets/sokoban/xmls/',
        'sprites': 'game/assets/sokoban/sprites/',
    }

    try:
        return gym.make('Sokoban-Medium-v0', **asset_dirs)
    except Exception as error:
        # Best-effort fallback: report why gym.make failed, then build the
        # environment class directly rather than aborting.
        print(str(error))
        return SokobanEnv(xml='medium.xml', **asset_dirs)


def run():
    """Run a short TD3 training session on Sokoban, then evaluate the model.

    A single ``RingMemory``, ``RLProcessor`` and ``TD3Model`` are shared by
    both phases. Each phase gets its own random process, a fresh environment
    from ``create_env()``, its own callback list and its own ``TD3Trainer``.
    """
    replay_memory = RingMemory()
    rl_processor = RLProcessor()

    td3_model = TD3Model(tau=1e-3, in_features=3, out_features=2)

    # NOTE(review): the save directory is named "02.ddpg" although the model
    # is TD3 -- confirm whether this is intentional.
    experiment, case, run_name = '.', '.', 'test'
    run_path = 'saves/02.ddpg/{}/{}/{}/'.format(experiment, case, run_name)

    def build_trainer(logger_cls, log_interval, **noise_options):
        """Assemble a trainer around the shared memory, processor and model.

        ``logger_cls`` is instantiated with ``run_path`` and
        ``log_interval``; ``noise_options`` are forwarded to the random
        process ahead of the fixed ``decay_steps`` value.
        """
        noise = RandomProcess(**noise_options, decay_steps=1200000)
        env = create_env()
        hooks = CallbackList(
            [
                logger_cls(run_path=run_path, interval=log_interval),
                Loader(model=td3_model, run_path=run_path, interval=1000),
                Renderer(environment=env),
            ]
        )
        return TD3Trainer(
            environment=env,
            memory=replay_memory,
            processor=rl_processor,
            model=td3_model,
            callbacks=hooks,
            random_process=noise,
        )

    # Training phase: random process with its default sigma settings.
    agent = build_trainer(TrainLogger, 100)
    agent.train(max_action=200, max_episode=2, warmup=12, replay_interval=32)

    # Evaluation phase: sigma and sigma_min are both 0.0 (presumably this
    # turns exploration noise off -- confirm against OrnsteinUhlenbeck).
    agent = build_trainer(ValidationLogger, 1, sigma=0.0, sigma_min=0.0)
    agent.evaluate(max_action=200)


# Script entry point: start the train-then-evaluate session only when this
# file is executed directly, not when it is imported as a module.
if __name__ == '__main__':
    run()