#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <time.h>
#include <stdarg.h>

#include "xmlsec/xmlsec.h"
#include "xmlsec/xmltree.h"
#include "xmlsec/xmldsig.h"
#include "xmlsec/crypto.h"
#include "xmlsec/parser.h"

#include "atobase/private/pall.h"

#include "atointernal.h"
#include "atotypes.h"
#include "atostmtkn.h"
#include "atostm.h"

/* Module identity handed to ato_initfnloglevel() by ato__stmtkn_init().
 * NOTE(review): presumably also referenced by the ATO_CTX_* error macros - confirm. */
static const char *_library = ATO_STM_LIBRARY;
static const char *_module = ATO_STM_MODULE_TKN;
static unsigned long _moduleid = ATO_STM_MODULEID_TKN;
/* Log level for this module: starts at ATO_LOG_WARN and is updated through _setloglevel(). */
static ato_eLoglevel _loglevel = ATO_LOG_WARN;

/* Opening text of the <AtoRSTR stkey="..."> wrapper element. _load_rstr() uses it both to
 * detect an RSTR that is already wrapped and to build the wrapper around a raw one. */
#define _ATO_RSTR_PREFIX "<AtoRSTR stkey=\""
/*********************************************************************************/
/* A security token parsed from an STS response, together with the wrapped response it came from. */
struct _ato_StmTkn {
    char *key;                  /* stkey argument, or /AtoRSTR/@stkey when the input was already wrapped */
    ato_String *atorstr;        /* the RSTR inside its <AtoRSTR stkey="..."> ... </AtoRSTR> wrapper */
    ato_Xml *xmlresponse;       /* parsed form of atorstr */
    ato_String *encryptedtoken; /* raw XML of the token node: xenc:EncryptedData, or a plain saml:Assertion */
    ato_String *hmac;           /* base64-decoded wst:BinarySecret of the RequestedProofToken */
    char *samlid;               /* o:KeyIdentifier text from the RequestedAttachedReference */
    time_t expirytime;          /* wst:Lifetime/wsu:Expires converted with ato_str2time() */
};

/*********************************************************************************/
/*
 * Register on xml every namespace prefix used by the XPath expressions that
 * query an AtoRSTR response. Prefixes are added in table order.
 */
static void _setresponsens(ato_Ctx *ctx, ato_Xml *xml)
{
    static const char *const nsmap[][2] = {
        { "s",    "http://www.w3.org/2003/05/soap-envelope" },
        { "a",    "http://www.w3.org/2005/08/addressing" },
        { "u",    "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd" },
        { "o",    "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd" },
        { "i",    "http://schemas.xmlsoap.org/ws/2005/05/identity" },
        { "sig",  "http://www.w3.org/2000/09/xmldsig#" },
        { "wsa",  "http://www.w3.org/2005/08/addressing" },
        { "wsp",  "http://schemas.xmlsoap.org/ws/2004/09/policy" },
        { "wst",  "http://docs.oasis-open.org/ws-sx/ws-trust/200512" },
        { "wsu",  "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd" },
        { "saml", "urn:oasis:names:tc:SAML:2.0:assertion" },
        { "xenc", "http://www.w3.org/2001/04/xmlenc#" }
    };
    size_t idx;

    assert(xml != NULL);
    for (idx = 0; idx < sizeof(nsmap) / sizeof(nsmap[0]); idx++) {
        ato_xml_addnamespace(ctx, xml, nsmap[idx][0], nsmap[idx][1]);
    }
}

/*
 * Parse an AtoRSTR document and populate the token fields of st.
 *
 * response is the STS RequestSecurityTokenResponse wrapped in an
 * <AtoRSTR stkey="..."> element. On success this sets:
 *   - st->xmlresponse;
 *   - st->hmac, st->encryptedtoken, st->samlid and st->expirytime;
 *   - st->key, taken from /AtoRSTR/@stkey, only if it was not already set.
 *
 * A missing, unreadable or undecodable element raises ATO_STM_ERR_NETRECEIVER
 * through the context's exception mechanism. Members already assigned to st
 * stay owned by st and are released by ato_stmtkn_free().
 */
static void _processresponse(ato_Ctx *ctx, ato_StmTkn *st, ato_String *response)
{
    struct exception_context *the_exception_context = ato__ctx_ec(ctx);
    static const char *function = "_processresponse";
    static const char *stkeyName = "/AtoRSTR/@stkey";
    static const char *EncryptedDataName1 = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:RequestedSecurityToken/xenc:EncryptedData";
    static const char *EncryptedDataName2 = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:RequestedSecurityToken/saml:EncryptedAssertion/xenc:EncryptedData";
    static const char *Assertion         = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:RequestedSecurityToken/saml:Assertion";
    static const char *BinarySecretName = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:RequestedProofToken/wst:BinarySecret";
    static const char *SamlId = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:RequestedAttachedReference/o:SecurityTokenReference/o:KeyIdentifier";
    static const char *ResponseBody_Expires = "/AtoRSTR/s:Envelope/s:Body/wst:RequestSecurityTokenResponseCollection/wst:RequestSecurityTokenResponse/wst:Lifetime/wsu:Expires";
    ato_String *hmackeyb64str = NULL;
    char *hmackeyb64 = NULL;
    char *expirytime = NULL;
    int errcode = ATO_ERR_OK;
    void *xnode = NULL;
    char *buf = NULL;
    char *stkey = NULL;
    ato_Log *log = ato_ctx_log(ctx);

    Try {
        ato_xml_create(ctx, &(st->xmlresponse), response);

        _setresponsens(ctx, st->xmlresponse);

        // The wrapper always carries an stkey attribute, so a missing one means a
        // malformed document. Check explicitly: an assert would vanish under NDEBUG.
        ato_xml_nodevalue(st->xmlresponse, NULL, stkeyName, &stkey, FALSE);
        if (stkey == NULL)
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not find stkey attribute");
        if (st->key == NULL) {
            st->key = stkey;    // ownership passes to st
        } else {
            // st->key is only preset by _load_rstr, which also wrote it into the wrapper.
            assert(strcmp(st->key, stkey) == 0);
            free(stkey);
        }

        if (ato_xml_nodevalue(st->xmlresponse, NULL, BinarySecretName, &hmackeyb64, FALSE) == NULL)
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not find HMAC element");
        // hmackeyb64str takes ownership of hmackeyb64 and is released after the Try.
        ato_str_create(&hmackeyb64str, hmackeyb64, strlen(hmackeyb64), TRUE);

        errcode = ato_base64decode(ctx, &(st->hmac), hmackeyb64str);
        if (errcode != ATO_ERR_OK)
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not decode prooftoken");

        // Locate the token: SAML2 encrypted assertion, then SAML1 encrypted data,
        // then an unencrypted assertion.
        xnode = ato_xml_findnode(st->xmlresponse, NULL, EncryptedDataName2, NULL);
        if (xnode == NULL) {
            ATO_LOG_MSG(log, ATO_LOG_WARN, "Could not find encrypted assertion (SAML2) - trying encrypted data (SAML1)");
            xnode = ato_xml_findnode(st->xmlresponse, NULL, EncryptedDataName1, NULL);
        }
        if (xnode == NULL) {
            ATO_LOG_MSG(log, ATO_LOG_WARN, "Could not find encrypted data (SAML1) - trying Assertion");
            xnode = ato_xml_findnode(st->xmlresponse, NULL, Assertion, NULL);
        }
        if (xnode == NULL) {
            ATO_LOG_MSG(log, ATO_LOG_WARN, "Could not find Assertion");
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not find SAML Token");
        }
        ato_xml_valueraw(st->xmlresponse, xnode, &buf, FALSE);
        if (buf == NULL)    // guard strlen(NULL)
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not read SAML Token");
        ato_str_create(&(st->encryptedtoken), buf, strlen(buf), TRUE);

        if (ato_xml_nodevalue(st->xmlresponse, NULL, SamlId, &(st->samlid), FALSE) == NULL)
            Throw ATO_CTX_VNEWERR(ctx, ATO_STM_ERR_NETRECEIVER, strlen(SamlId), "Could not find KeyIdentifier (SAMLID) element %s", SamlId);

        if (ato_xml_nodevalue(st->xmlresponse, NULL, ResponseBody_Expires, &expirytime, FALSE) == NULL)
            Throw ATO_CTX_NEWERR(ctx, ATO_STM_ERR_NETRECEIVER, "Could not find expiry time element");

        // e.g. 2011-06-30T03:28:07.835Z
        st->expirytime = ato_str2time(expirytime);
    } Catch(errcode) {
    }

    hmackeyb64str = ato_str_free(hmackeyb64str);
    expirytime = ato_free(expirytime);
    ATO_RETHROW_ONERR(errcode);
}

/*********************************************************************************/
/* Log-level callback registered through ato_initfnloglevel(); stores the new level for this module. */
static void _setloglevel(ato_eLoglevel level)
{
    ato_eLoglevel *current = &_loglevel;
    *current = level;
}

/*********************************************************************************/
/*
 * One-time module initialisation: registers this module's log-level callback.
 * Later calls do nothing. Always returns ATO_ERR_OK.
 * The guard is a plain static flag, so concurrent first calls are not serialised.
 */
int ato__stmtkn_init(void)
{
    static bool initialised = FALSE;

    if (!initialised) {
        initialised = TRUE;
        ato_initfnloglevel(_library, _module, _moduleid, _loglevel, _setloglevel);
    }
    return ATO_ERR_OK;
}

/* Module teardown counterpart of ato__stmtkn_init(); there is nothing to release. */
void ato__stmtkn_deinit(void)
{
    return;
}
/*********************************************************************************/
/*
 * Return the XML entity that must replace c inside a double-quoted attribute
 * value, or NULL when c can be written literally.
 */
static const char *_xmlattrentity(char c)
{
    switch (c) {
    case '&': return "&amp;";
    case '<': return "&lt;";
    case '>': return "&gt;";
    case '"': return "&quot;";
    default:  return NULL;
    }
}

/*
 * Return a newly allocated copy of text escaped for use inside a double-quoted
 * XML attribute value (caller frees), or NULL if allocation fails.
 */
static char *_xmlattrescape(const char *text)
{
    size_t outlen = 0;
    const char *p = NULL;
    const char *entity = NULL;
    char *out = NULL;
    char *q = NULL;

    for (p = text; *p != '\0'; p++) {
        entity = _xmlattrentity(*p);
        outlen += (entity != NULL) ? strlen(entity) : 1;
    }

    out = calloc(outlen + 1, sizeof(char));
    if (out == NULL)
        return NULL;

    for (p = text, q = out; *p != '\0'; p++) {
        entity = _xmlattrentity(*p);
        if (entity != NULL) {
            size_t elen = strlen(entity);
            memcpy(q, entity, elen);
            q += elen;
        } else {
            *q++ = *p;
        }
    }
    *q = '\0';
    return out;
}

/*
 * Store rstr in st->atorstr. If rstr does not already start with the
 * <AtoRSTR stkey="..."> wrapper, it is wrapped, using stkey as the attribute
 * value. stkey must then be non-empty and is also copied into st->key.
 * Otherwise rstr is duplicated as-is, and stkey and st->key are left untouched.
 *
 * stkey is XML-escaped before it is embedded. Without this, a key containing
 * '&', '<' or '"' (e.g. a URL with a query string) yields a document that does
 * not parse. The parser unescapes it again, so /AtoRSTR/@stkey still equals
 * stkey.
 */
static void _load_rstr(ato_Ctx *ctx, ato_StmTkn *st, ato_String *rstr, const char *stkey)
{
    static const char *s1 = _ATO_RSTR_PREFIX;
    static const char *s2 = "\">";
    static const char *s3 = "</AtoRSTR>";
    ATO_IGNORE(ctx);

    if (!ato_strstartswith(ato_str_value(rstr), _ATO_RSTR_PREFIX)) {
        char *escapedkey = NULL;
        char *content = NULL;
        size_t s1len = strlen(s1);
        size_t s2len = strlen(s2);
        size_t s3len = strlen(s3);
        size_t rstrlen = ato_str_len(rstr);
        size_t keylen = 0;
        size_t len = 0;

        assert(!ato_isnullorempty(stkey));
        st->key = ato_strdup(stkey, 0);

        escapedkey = _xmlattrescape(stkey);
        assert(escapedkey != NULL);
        keylen = strlen(escapedkey);

        // Exact size: prefix + key + "\">" + body + closing tag + NUL.
        content = calloc(s1len + keylen + s2len + rstrlen + s3len + 1, sizeof(char));
        assert(content != NULL);

        memcpy(content + len, s1, s1len);
        len += s1len;
        memcpy(content + len, escapedkey, keylen);
        len += keylen;
        memcpy(content + len, s2, s2len);
        len += s2len;
        // rstr is copied by length, not as a C string.
        memcpy(content + len, ato_str_value(rstr), rstrlen);
        len += rstrlen;
        memcpy(content + len, s3, s3len);
        len += s3len;
        content[len] = '\0';

        free(escapedkey);
        // st->atorstr takes ownership of content.
        ato_str_create(&(st->atorstr), content, len, TRUE);
    } else {
        ato_str_dup(&(st->atorstr), rstr);
        assert(st->atorstr != NULL);
    }
}
/*********************************************************************************/
/*
 * Create a token object from an STS response.
 *
 * obj   - receives the new object. ATO_ASSERT_ISNOTALLOCATED(obj) checks its
 *         incoming state (presumably that *obj is NULL - confirm).
 * rstr  - the RSTR; must not be NULL. If it does not already start with the
 *         <AtoRSTR stkey="..."> wrapper, it is wrapped using stkey.
 * stkey - key for the token; only used when rstr is not already wrapped.
 *
 * *obj is assigned before the response is parsed. If _processresponse() raises
 * an error, the partially populated object is therefore still reachable
 * through *obj.
 * NOTE(review): presumably the caller then releases it with ato_stmtkn_free() - confirm.
 */
void ato__stmtkn_create(ato_Ctx *ctx, ato_StmTkn **obj, ato_String *rstr, const char *stkey)
{
    static const char *function = "ato__stmtkn_create";
    int errcode = ATO_ERR_OK;   // never changed here; errors surface via the exception from _processresponse()
    ato_StmTkn *st = NULL;

    ATO_CTX_FN_START(ctx);

    ATO_ASSERT_ISNOTALLOCATED(obj);
    assert(rstr != NULL);

    *obj = st = calloc(1, sizeof(ato_StmTkn));
    assert(st != NULL);

    // Wrap/duplicate the raw RSTR into st->atorstr, then parse it into the token fields.
    _load_rstr(ctx, st, rstr, stkey);
    _processresponse(ctx, st, st->atorstr);

    ATO_CTX_FN_END(ctx, errcode);
}

/*
 * Release a token object and every member it owns. NULL is accepted and ignored.
 */
void ato_stmtkn_free(ato_StmTkn *st)
{
    if (st != NULL) {
        // parsed response first, then the string members, then the plain C strings
        ato_xml_free(st->xmlresponse);
        st->hmac = ato_str_free(st->hmac);
        st->encryptedtoken = ato_str_free(st->encryptedtoken);
        st->atorstr = ato_str_free(st->atorstr);

        st->samlid = ato_free(st->samlid);
        st->key = ato_free(st->key);

        free(st);
    }
}

/*********************************************************************************/
/*
 * Accessors. Each returns a member of the token object without copying it.
 * The token keeps ownership, and ato_stmtkn_free() releases the value.
 */

/* The RSTR inside its <AtoRSTR stkey="..."> wrapper. */
ato_String *ato_stmtkn_xml(ato_StmTkn *st)
{
    ato_String *xml = NULL;
    assert(st != NULL);
    xml = st->atorstr;
    return xml;
}

/* The key the token was created or loaded with. */
const char *ato_stmtkn_key(ato_StmTkn *st)
{
    const char *key = NULL;
    assert(st != NULL);
    key = st->key;
    return key;
}

/* The base64-decoded proof token (wst:BinarySecret). */
ato_String *ato_stmtkn_prooftoken(ato_StmTkn *st)
{
    ato_String *proof = NULL;
    assert(st != NULL);
    proof = st->hmac;
    return proof;
}

/* Raw XML of the security token node (encrypted data or plain assertion). */
ato_String *ato_stmtkn_assertion(ato_StmTkn *st)
{
    ato_String *token = NULL;
    assert(st != NULL);
    token = st->encryptedtoken;
    return token;
}

/* The KeyIdentifier (SAML id) from the response. */
const char *ato_stmtkn_samlid(ato_StmTkn *st)
{
    const char *id = NULL;
    assert(st != NULL);
    id = st->samlid;
    return id;
}

/*
 * Report whether the token has expired, i.e. whether the current time is at or
 * past st->expirytime (the wst:Lifetime/wsu:Expires value of the response).
 *
 * If the system clock cannot be read, time() returns (time_t)-1. The token is
 * then reported as expired rather than being compared against that sentinel,
 * so a possibly stale token is not reused.
 */
bool ato_stmtkn_isexpired(ato_StmTkn *st)
{
    time_t now = (time_t)-1;

    assert(st != NULL);
    if (time(&now) == (time_t)-1)
        return TRUE;
    return difftime(st->expirytime, now) <= 0;
}

/* Expiry time of the token, as parsed from the response's wst:Lifetime/wsu:Expires. */
time_t ato_stmtkn_expirytime(ato_StmTkn *st)
{
    time_t expiry;
    assert(st != NULL);
    expiry = st->expirytime;
    return expiry;
}
