#!/usr/bin/python3

# For testing the zmq servers

import zmq
from time import sleep
from threading import Thread
from ptcommon.ptdm_message import Message


def cleanup():
    """Stop the publish listener thread and close the request socket.

    Clears the module-level ``continue_listening`` flag that the listener
    loop checks, waits for that thread to finish, then closes the
    module-level request socket.
    """
    global continue_listening

    # The listener loop re-reads this flag between polls, so it exits on its
    # own shortly after the flag is cleared.
    continue_listening = False
    thread.join()

    zmq_socket.close()


def _fail_test(reason):
    """Report a failed test, release the sockets and abort the script.

    Raises SystemExit(1) after cleanup() because cleanup() closes the request
    socket, so no later request could be sent on it anyway.
    """
    print(reason)
    cleanup()
    raise SystemExit(1)


def test_request(test_name, message_request_id, parameters, expected_response_id, expected_response_parameters):
    """Send one request to the request server and check the response.

    Args:
        test_name: Label printed in the test banner.
        message_request_id: Message ID of the request to send.
        parameters: Parameters to attach to the request.
        expected_response_id: Message ID the response must carry.
        expected_response_parameters: Values the leading response parameters
            must equal, compared position by position. Response parameters
            beyond this list are not checked.

    Raises:
        SystemExit: With status 1 if the response ID or any expected
            parameter does not match (after cleanup() has been called).
    """
    print("[Test " + test_name + "]")

    message = Message.from_parts(message_request_id, parameters)
    zmq_socket.send_string(message.to_string())

    response = Message.from_string(zmq_socket.recv_string())

    if response.message_id() != expected_response_id:
        _fail_test("Test failed. Unexpected response: " + str(response.message_id()))

    response_parameters = response.parameters()
    for index, expected in enumerate(expected_response_parameters):
        try:
            actual = response_parameters[index]
        except IndexError:
            # The response carried fewer parameters than the test expects;
            # report it as a test failure rather than an unhandled IndexError.
            _fail_test("Test failed. Missing response parameter " + str(index) + ". Expected: " + str(expected))
        else:
            if expected != actual:
                _fail_test("Test failed. Unexpected response parameter. Expected: " + str(expected) + ", got:" + str(actual))

    print("[Test passed]")


def listen_thread():
    """Subscribe to the publish server and drain events until told to stop.

    Runs as the target of the listener thread. Connects a SUB socket to
    tcp://127.0.0.1:3781 with an empty subscription prefix, then polls it
    until the module-level ``continue_listening`` flag becomes false. Each
    received message is parsed with ``Message.from_string``; the parsed
    message is not otherwise used. The socket and its context are released
    when the loop ends, including when it ends because of an exception.
    """
    print("Connecting to publish server...")
    zmq_context_listen = zmq.Context()
    subscriber = zmq_context_listen.socket(zmq.SUB)

    try:
        subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
        subscriber.connect("tcp://127.0.0.1:3781")
        print("Connected to publish server")

        sleep(0.5)

        # One poller is enough for the whole loop; the registration does not
        # change between iterations.
        poller = zmq.Poller()
        poller.register(subscriber, zmq.POLLIN)

        while continue_listening:
            # The 500 ms poll timeout bounds how long a cleared
            # continue_listening flag can go unnoticed.
            for _ in poller.poll(500):
                message_string = subscriber.recv_string()
                message = Message.from_string(message_string)
                # print("Received publish event: " + message.message_friendly_string())
    finally:
        # linger=0 discards anything still queued so that terminating the
        # context cannot block waiting for an unreachable peer.
        subscriber.close(linger=0)
        zmq_context_listen.term()


def test_ping():
    """Send a ping request and expect a ping response; no parameters are checked."""
    print("*** Test ping ***")
    test_request(
        test_name="Ping",
        message_request_id=Message.REQ_PING,
        parameters=[],
        expected_response_id=Message.RSP_PING,
        expected_response_parameters=[],
    )


def _test_get_brightness():
    """Request the current brightness and check only the response message ID.

    The expected-parameter list is empty, so the returned brightness value
    itself is not verified.
    """
    print("*** Test getting brightness ***")
    test_request(
        test_name="Get brightness",
        message_request_id=Message.REQ_GET_BRIGHTNESS,
        parameters=[],
        expected_response_id=Message.RSP_GET_BRIGHTNESS,
        expected_response_parameters=[],
    )


def _test_set_brightness():
    """Step the brightness from 10 down to 0 and back up to 10 in steps of 2.

    Sends one set-brightness request per level, pausing one second between
    consecutive requests. Only the response message ID is checked.
    """
    print("*** Test setting brightness ***")

    # Levels are sent as strings, in exactly this order.
    levels = ("10", "8", "6", "4", "2", "0", "2", "4", "6", "8", "10")

    for position, level in enumerate(levels):
        if position > 0:
            # Pause between requests only, not after the final one.
            sleep(1)
        test_request(
            "Set brightness to " + level,
            Message.REQ_SET_BRIGHTNESS,
            [level],
            Message.RSP_SET_BRIGHTNESS,
            [],
        )


def _test_inc_dec_brightness():
    """Decrement the brightness three times, then increment it three times.

    Requests are sent one second apart. Only the response message ID is
    checked for each request.
    """
    print("*** Test incrementing/decrementing brightness ***")

    decrement = ("Decrement brightness", Message.REQ_DECREMENT_BRIGHTNESS, Message.RSP_DECREMENT_BRIGHTNESS)
    increment = ("Increment brightness", Message.REQ_INCREMENT_BRIGHTNESS, Message.RSP_INCREMENT_BRIGHTNESS)
    steps = [decrement] * 3 + [increment] * 3

    for position, (label, request_id, response_id) in enumerate(steps):
        if position > 0:
            # Pause between requests only, not after the final one.
            sleep(1)
        test_request(label, request_id, [], response_id, [])


def test_all_brightness():
    """Run the get, set and increment/decrement brightness tests in order.

    A one-second pause separates consecutive stages.
    """
    stages = (_test_get_brightness, _test_set_brightness, _test_inc_dec_brightness)

    for position, run_stage in enumerate(stages):
        if position > 0:
            sleep(1)
        run_stage()


def test_device_id():
    """Ask the operator which device is in use and verify the reported device ID.

    Prompts repeatedly until the answer is "1" or "2", then expects the
    get-device-ID response to carry that number plus one as its first
    parameter (as a string).
    """
    print("*** Test getting device ID ***")

    choice = None
    while choice not in ("1", "2"):
        choice = input("Are you running on a pi-top (1) or a pi-topCEED (2)?")

    # Expected ID is the menu choice plus one: "1" -> "2", "2" -> "3".
    expected_device_id = str(int(choice) + 1)
    test_request(
        "Get DeviceID",
        Message.REQ_GET_DEVICE_ID,
        [],
        Message.RSP_GET_DEVICE_ID,
        [expected_device_id],
    )


def test_screen_blank_unblank():
    """Blank the screen, wait one second, then unblank it.

    Only the response message ID is checked for each request.
    """
    print("*** Test screen blanking ***")
    test_request(
        test_name="Blank Screen",
        message_request_id=Message.REQ_BLANK_SCREEN,
        parameters=[],
        expected_response_id=Message.RSP_BLANK_SCREEN,
        expected_response_parameters=[],
    )
    sleep(1)
    test_request(
        test_name="Unblank Screen",
        message_request_id=Message.REQ_UNBLANK_SCREEN,
        parameters=[],
        expected_response_id=Message.RSP_UNBLANK_SCREEN,
        expected_response_parameters=[],
    )

# --- Script body: start the publish listener, connect, run every test ---

print("Starting thread...")

# Flag read by listen_thread(); cleanup() clears it to stop the listener.
continue_listening = True
thread = Thread(target=listen_thread)
thread.start()

print("Connecting to request server...")
zmq_context_send = zmq.Context()
zmq_socket = zmq_context_send.socket(zmq.REQ)
zmq_socket.connect("tcp://127.0.0.1:3782")
print("Connected to request server.")

try:
    sleep(0.5)

    test_ping()
    sleep(1)
    test_all_brightness()
    sleep(1)
    test_device_id()
    sleep(1)
    test_screen_blank_unblank()
    sleep(1)

    print("*** End of tests ***")

    print("Waiting for any further events...")
    sleep(5)
finally:
    # Always stop the listener thread and close the request socket, even if a
    # test raised or the run was interrupted. Otherwise the non-daemon
    # listener thread keeps polling and the process never exits.
    print("Closing sockets...")
    cleanup()

print("Done.")
