/*
 * libsyncml - A syncml protocol implementation
 * Copyright (C) 2005  Armin Bauer <armin.bauer@opensync.org>
 * Copyright (C) 2007-2009  Michael Bell <michael.bell@opensync.org>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
 *
 */

#include "sml_elements_internals.h"

#include <string.h>

#include "sml_error_internals.h"
#include "sml_support.h"

#include "sml_parse.h"
#include "sml_command.h"
#include "data_sync_api/sml_location.h"

/**
 * Creates a new anchor that holds private copies of the given strings.
 *
 * @param last  previous sync anchor (copied, may be NULL)
 * @param next  next sync anchor (copied, may be NULL)
 * @param error location for a GError if the allocation fails
 * @return the new anchor, or NULL on failure
 */
SmlAnchor*
smlAnchorNew (const gchar *last,
              const gchar *next,
              GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%s, %s, %p)", __func__, VA_STRING(last), VA_STRING(next), error);
	CHECK_ERROR_REF

	SmlAnchor *result = smlTryMalloc0(sizeof(SmlAnchor), error);
	if (result == NULL) {
		smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
		return NULL;
	}

	result->last = g_strdup(last);
	result->next = g_strdup(next);

	smlTrace(TRACE_EXIT, "%s: %p", __func__, result);
	return result;
}

/**
 * Releases an anchor together with the last/next strings it owns.
 *
 * @param anchor the anchor to destroy (must not be NULL)
 */
void
smlAnchorFree (SmlAnchor *anchor)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, anchor);
	smlAssert(anchor);

	/* Drop the owned strings first, then the container itself. */
	if (anchor->last != NULL)
		smlSafeCFree(&(anchor->last));
	if (anchor->next != NULL)
		smlSafeCFree(&(anchor->next));

	smlSafeFree((gpointer *)&anchor);

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/** Returns the anchor's "last" string; the anchor keeps ownership. */
const gchar*
smlAnchorGetLast (SmlAnchor *anchor)
{
	const gchar *last = anchor->last;
	return last;
}

/** Returns the anchor's "next" string; the anchor keeps ownership. */
const gchar*
smlAnchorGetNext (SmlAnchor *anchor)
{
	const gchar *next = anchor->next;
	return next;
}

/** Returns the session ID stored in the header. */
gsize
smlHeaderGetSessionID (SmlHeader *header)
{
	const gsize id = header->sessionID;
	return id;
}

/** Returns the message ID stored in the header. */
gsize
smlHeaderGetMessageID (SmlHeader *header)
{
	const gsize id = header->messageID;
	return id;
}

/** Returns the protocol version recorded in the header. */
SmlProtocolVersion
smlHeaderGetProtocolVersion (SmlHeader *header)
{
	const SmlProtocolVersion version = header->version;
	return version;
}

/** Returns the protocol type recorded in the header. */
SmlProtocolType
smlHeaderGetProtocolType (SmlHeader *header)
{
	const SmlProtocolType protocol = header->protocol;
	return protocol;
}

/** Returns the header's source location; no extra reference is taken. */
SmlLocation*
smlHeaderGetSource (SmlHeader *header)
{
	SmlLocation *source = header->source;
	return source;
}

/** Returns the header's target location; no extra reference is taken. */
SmlLocation*
smlHeaderGetTarget (SmlHeader *header)
{
	SmlLocation *target = header->target;
	return target;
}

/** Returns the maximum message size stored in the header. */
gsize
smlHeaderGetMaxMsgSize (SmlHeader *header)
{
	const gsize limit = header->maxmsgsize;
	return limit;
}

/** Returns the header's response URI; the header keeps ownership. */
const gchar*
smlHeaderGetResponseURI (SmlHeader *header)
{
	const gchar *uri = header->responseURI;
	return uri;
}

/**
 * Destroys a header.
 *
 * Frees the emi and responseURI strings, drops the references held on the
 * source and target locations, and finally frees the header itself.
 */
void
smlHeaderFree (SmlHeader *header)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, header);

	/* owned strings */
	if (header->emi != NULL)
		smlSafeCFree(&(header->emi));

	/* referenced location objects */
	if (header->source != NULL)
		g_object_unref(header->source);
	if (header->target != NULL)
		g_object_unref(header->target);

	if (header->responseURI != NULL)
		smlSafeCFree(&(header->responseURI));

	smlSafeFree((gpointer *)&header);

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Creates an empty item with a reference count of one.
 *
 * @param size  expected total data size; 0 means "unknown / unlimited"
 * @param error location for a GError if the allocation fails
 * @return the new item, or NULL on failure
 */
SmlItem*
smlItemNew (gsize size,
            GError **error)
{
	/* gsize must be traced with G_GSIZE_FORMAT; "%i" mismatches on LP64. */
	smlTrace(TRACE_ENTRY, "%s(%" G_GSIZE_FORMAT ", %p)", __func__, size, error);
	CHECK_ERROR_REF

	SmlItem *item = smlTryMalloc0(sizeof(SmlItem), error);
	if (!item)
		goto error;

	item->refCount = 1;
	item->size = size;

	smlTrace(TRACE_EXIT, "%s: %p", __func__, item);
	return item;
error:
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Creates an item with the given expected size and fills it with @data.
 *
 * If @data is NULL, this call is the same as smlItemNew(size, error).
 *
 * @param data  bytes to copy into the item (may be NULL)
 * @param size  number of bytes in @data and the item's expected size
 * @param error location for a GError on failure
 * @return the new item, or NULL on failure
 */
SmlItem*
smlItemNewForData (const gchar *data,
                   gsize size,
                   GError **error)
{
	/* gsize must be traced with G_GSIZE_FORMAT; "%i" mismatches on LP64. */
	smlTrace(TRACE_ENTRY, "%s(%p, %" G_GSIZE_FORMAT ", %p)", __func__, data, size, error);
	CHECK_ERROR_REF

	SmlItem *item = smlItemNew(size, error);
	if (!item)
		goto error;

	if (data) {
		if (!smlItemAddData(item, data, size, error))
			goto error;
	}

	smlTrace(TRACE_EXIT, "%s: %p", __func__, item);
	return item;
error:
	if (item)
		smlItemUnref(item);
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Atomically adds a reference to an item.
 *
 * @param item the item (must not be NULL)
 * @return the same item, for call chaining
 */
SmlItem*
smlItemRef (SmlItem *item)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, item);
	smlAssert(item);

	gint *counter = &(item->refCount);
	g_atomic_int_inc(counter);

	smlTrace(TRACE_EXIT, "%s: New refcount: %i", __func__, item->refCount);
	return item;
}

/**
 * Atomically drops a reference to an item and destroys it on the last one.
 *
 * Destruction releases the source/target location references, the anchor,
 * the data buffer and the content type before freeing the item itself.
 *
 * @param item the item (must not be NULL)
 */
void
smlItemUnref(SmlItem *item)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, item);
	smlAssert(item);

	if (!g_atomic_int_dec_and_test(&(item->refCount))) {
		/* Other owners remain, so the item is still valid here. */
		smlTrace(TRACE_EXIT, "%s: %i", __func__, g_atomic_int_get(&(item->refCount)));
		return;
	}

	smlTrace(TRACE_INTERNAL, "%s: Refcount == 0!", __func__);

	if (item->source)
		g_object_unref(item->source);

	if (item->target)
		g_object_unref(item->target);

	if (item->anchor)
		smlAnchorFree(item->anchor);

	if (item->buffer)
		xmlBufferFree(item->buffer);

	if (item->contenttype)
		smlSafeCFree(&(item->contenttype));

	smlSafeFree((gpointer *)&item);

	/* The item has been freed; never read it again, just report zero. */
	smlTrace(TRACE_EXIT, "%s: %i", __func__, 0);
}

/**
 * Appends @size bytes from @data to the item's buffer.
 *
 * When the item has a non-zero expected size, the total amount of buffered
 * data may never exceed it; this limit is checked even if @data is NULL.
 * A NULL @data adds nothing.
 *
 * @param item  the item (must not be NULL)
 * @param data  bytes to append (may be NULL)
 * @param size  number of bytes in @data
 * @param error location for a GError on failure
 * @return TRUE on success, FALSE with @error set otherwise
 */
gboolean
smlItemAddData (SmlItem *item,
                const gchar *data,
                gsize size,
                GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %" G_GSIZE_FORMAT ", %p)", __func__, item, data, size, error);
	CHECK_ERROR_REF
	smlAssert(item);

	if (item->size) {
		gsize used = (gsize) xmlBufferLength(item->buffer);
		/* Compare via subtraction so that a huge size cannot wrap the sum. */
		if (used > item->size || size > item->size - used) {
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unable to add data. size limit reached");
			goto error;
		}
	}

	if (data) {
		/* xmlBufferAdd() takes an int length (and -1 means "use strlen"),
		 * so sizes beyond G_MAXINT must be rejected instead of truncated. */
		if (size > G_MAXINT) {
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unable to add data. Data is too large.");
			goto error;
		}

		if (!item->buffer) {
			/* Preallocate the expected total size if it is known. */
			item->buffer = xmlBufferCreateSize(item->size ? item->size : size);
			if (!item->buffer) {
				g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unable to allocate the item buffer.");
				goto error;
			}
		}

		if (xmlBufferAdd(item->buffer, (const xmlChar *)data, (int) size) != 0) {
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unable to add data.");
			goto error;
		}
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return TRUE;
error:
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return FALSE;
}

/**
 * Checks whether the item is complete.
 *
 * An item with an expected size of 0 is always considered complete.
 * Otherwise the buffered data length must equal the expected size.
 *
 * @param item the item (must not be NULL)
 * @return TRUE if complete, FALSE otherwise
 */
gboolean
smlItemCheck (SmlItem *item)
{
	/* Validate the pointer before anything dereferences it. */
	smlAssert(item);
	smlAssert(xmlBufferLength(item->buffer) >= 0);

	if (!item->size)
		return TRUE;

	if ((gsize) xmlBufferLength(item->buffer) != item->size)
	{
		smlTrace(TRACE_INTERNAL, "%s: failed because size (%" G_GSIZE_FORMAT " != %i) does not match (%s).",
			__func__, item->size,
			xmlBufferLength(item->buffer), VA_STRING((char *)xmlBufferContent(item->buffer)));
		return FALSE;
	}

	return TRUE;
}

/** Reports whether the item currently owns a data buffer. */
gboolean
smlItemHasData (SmlItem *item)
{
	smlAssert(item);
	if (item->buffer == NULL)
		return FALSE;
	return TRUE;
}

/**
 * Hands the item's data over to the caller.
 *
 * On success *data receives a newly allocated copy of exactly *size bytes,
 * followed by a terminating NUL. Embedded NUL bytes are preserved. The
 * caller owns the copy and releases it with g_free(); it is not freed when
 * the item is unrefd. The item's internal buffer is freed, so
 * smlItemHasData() reports FALSE afterwards. If the item has no content,
 * *data is NULL.
 *
 * @param item  the item (must not be NULL)
 * @param data  receives the copied data (must not be NULL)
 * @param size  receives the data length (must not be NULL)
 * @param error location for a GError on failure
 * @return TRUE on success, FALSE if the item is incomplete
 */
gboolean
smlItemStealData (SmlItem *item,
                  gchar **data,
                  gsize *size,
                  GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p, %p)", __func__, item, data, size, error);
	CHECK_ERROR_REF
	smlAssert(data);
	smlAssert(size);

	if (!smlItemCheck(item)) {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Item check failed");
		goto error;
	}

	const xmlChar *content = xmlBufferContent(item->buffer);
	gsize length = (gsize) xmlBufferLength(item->buffer);

	if (content) {
		/* Copy by length, not with g_strndup(), which stops at the first
		 * NUL byte and would corrupt binary payloads. */
		gchar *copy = g_malloc(length + 1);
		memcpy(copy, content, length);
		copy[length] = '\0';
		*data = copy;
	} else {
		*data = NULL;
	}
	*size = length;

	if (item->buffer) {
		xmlBufferFree(item->buffer);
		item->buffer = NULL;
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return TRUE;
error:
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return FALSE;
}

/**
 * Exposes the item's data without copying it.
 *
 * *data points into the item's internal buffer and becomes invalid once
 * the item is unrefd or its data is stolen.
 *
 * @return TRUE on success, FALSE (with @error set) if the item is incomplete
 */
gboolean
smlItemGetData (SmlItem *item,
                gchar **data,
                gsize *size,
                GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p, %p)", __func__, item, data, size, error);
	CHECK_ERROR_REF

	if (smlItemCheck(item)) {
		*data = (char *)xmlBufferContent(item->buffer);
		*size = xmlBufferLength(item->buffer);
		smlTrace(TRACE_EXIT, "%s", __func__);
		return TRUE;
	}

	g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Item check failed");
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return FALSE;
}

/** Returns the number of bytes currently buffered in the item. */
gsize
smlItemGetSize (SmlItem *item)
{
	smlAssert(item);
	const int length = xmlBufferLength(item->buffer);
	return length;
}

/** Returns a pointer to the item's buffered bytes; the item keeps ownership. */
const gchar*
smlItemGetContent (SmlItem *item)
{
	smlAssert(item);
	const xmlChar *content = xmlBufferContent(item->buffer);
	return (gchar *)content;
}

/**
 * Sets the item's source location and takes a reference on it.
 *
 * Any previously set source is released, so repeated calls do not leak.
 *
 * @param item   the item (must not be NULL)
 * @param source the new source location (must not be NULL)
 */
void
smlItemSetSource (SmlItem *item,
                  SmlLocation *source)
{
	smlAssert(item);
	smlAssert(source);

	/* Ref first so that re-setting the same location cannot drop it to zero. */
	g_object_ref(source);
	if (item->source)
		g_object_unref(item->source);
	item->source = source;
}

/** Returns the item's source location; no extra reference is taken. */
SmlLocation*
smlItemGetSource (SmlItem *item)
{
	smlAssert(item);
	SmlLocation *source = item->source;
	return source;
}

/**
 * Sets the item's target location and takes a reference on it.
 *
 * Any previously set target is released, so repeated calls do not leak.
 *
 * @param item   the item (must not be NULL)
 * @param target the new target location (must not be NULL)
 */
void
smlItemSetTarget (SmlItem *item,
                  SmlLocation *target)
{
	smlAssert(item);
	smlAssert(target);

	/* Ref first so that re-setting the same location cannot drop it to zero. */
	g_object_ref(target);
	if (item->target)
		g_object_unref(item->target);
	item->target = target;
}

/** Returns the item's target location; no extra reference is taken. */
SmlLocation*
smlItemGetTarget (SmlItem *item)
{
	smlAssert(item);
	SmlLocation *target = item->target;
	return target;
}

/** Stores the item's "raw" flag. */
void
smlItemSetRaw (SmlItem *item,
               gboolean raw)
{
	smlAssert(item);
	item->raw = raw;
}

/** Stores the item's "more data" flag. */
void
smlItemSetMoreData (SmlItem *item,
                    gboolean enable)
{
	const gboolean flag = enable;
	item->moreData = flag;
}

/** Returns the item's "more data" flag. */
gboolean
smlItemGetMoreData (SmlItem *item)
{
	const gboolean flag = item->moreData;
	return flag;
}

/**
 * Sets the item's content type to a private copy of @ct.
 *
 * Any previously stored content type is released, so repeated calls do not
 * leak.
 *
 * @param item  the item
 * @param ct    the content type (copied, may be NULL)
 * @param error location for a GError (currently never set)
 * @return TRUE
 */
gboolean
smlItemSetContentType (SmlItem *item,
                       const gchar *ct,
                       GError **error)
{
	CHECK_ERROR_REF
	/* Copy before freeing in case ct aliases the current value. */
	gchar *copy = g_strdup(ct);
	if (item->contenttype)
		smlSafeCFree(&(item->contenttype));
	item->contenttype = copy;
	return TRUE;
}

/** Returns the item's content type; the item keeps ownership. */
const gchar*
smlItemGetContentType (SmlItem *item)
{
	const gchar *ct = item->contenttype;
	return ct;
}

/**
 * Creates a new item containing bytes [start, start + space) of @orig_item.
 *
 * The fragment has moreData set when data remains after it. It shares
 * (references) the original's source and target locations and gets a copy
 * of its content type.
 *
 * @param orig_item the item to slice (must not be NULL)
 * @param start     offset of the first byte of the fragment
 * @param space     number of bytes in the fragment
 * @param error     location for a GError on failure
 * @return the new fragment, or NULL if the range is out of bounds or
 *         allocation fails
 */
SmlItem *
smlItemGetFragment (SmlItem *orig_item,
                    gsize start,
                    gsize space,
                    GError **error)
{
	smlAssert(orig_item);

	const char *data = (const char *)xmlBufferContent(orig_item->buffer);
	gsize size = (gsize) xmlBufferLength(orig_item->buffer);
	SmlItem *frag_item = NULL;

	/* Reject ranges that would read past the end of the original data
	 * (written without start + space to avoid overflow). */
	if (start > size || space > size - start) {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
			"The requested fragment (start %" G_GSIZE_FORMAT ", length %" G_GSIZE_FORMAT ") exceeds the item size %" G_GSIZE_FORMAT ".",
			start, space, size);
		goto error;
	}

	/* Never do pointer arithmetic on a NULL buffer. */
	frag_item = smlItemNewForData(data ? data + start : NULL, space, error);
	if (!frag_item)
		goto error;

	frag_item->moreData = (start + space < size) ? TRUE : FALSE;

	frag_item->target = orig_item->target;
	if (frag_item->target)
		g_object_ref(frag_item->target);

	frag_item->source = orig_item->source;
	if (frag_item->source)
		g_object_ref(frag_item->source);

	frag_item->contenttype = g_strdup(orig_item->contenttype);

	return frag_item;
error:
	if (frag_item)
		smlItemUnref(frag_item);
	return NULL;
}

/**
 * Creates a credential from its textual type and format names.
 *
 * A NULL @type means basic authentication. A NULL @format means base64.
 *
 * @param type   SML_AUTH_BASIC, SML_AUTH_MD5 or NULL
 * @param format SML_BASE64 or NULL
 * @param data   the credential payload (must not be NULL)
 * @param error  location for a GError on failure
 * @return the new credential, or NULL on failure
 */
SmlCred*
smlCredNewFromString (const gchar *type,
                      const gchar *format,
                      const gchar *data,
                      GError **error)
{
	/* Trace the arguments in signature order. */
	smlTrace(TRACE_ENTRY, "%s(%s, %s, %s, %p)", __func__, VA_STRING(type), VA_STRING(format), VA_STRING(data), error);
	CHECK_ERROR_REF

	SmlAuthType smlType = SML_AUTH_TYPE_UNKNOWN;
	SmlFormatType smlFormat = SML_FORMAT_TYPE_UNKNOWN;
	SmlCred *cred = NULL;

	if (!type || !strcmp(type, SML_AUTH_BASIC)) {
		smlType = SML_AUTH_TYPE_BASIC;
	} else if (!strcmp(type, SML_AUTH_MD5)) {
		smlType = SML_AUTH_TYPE_MD5;
	} else {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unknown type - %s.", type);
		goto error;
	}

	if (!format || !strcmp(format, SML_BASE64)) {
		smlFormat = SML_FORMAT_TYPE_BASE64;
	} else {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "SyncML credential: Unknown format - %s.", format);
		goto error;
	}

	if (!data)  {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Data is missing in %s.", __func__);
		goto error;
	}

	/* Only report success once the credential really exists. */
	cred = smlCredNew(smlType, smlFormat, data, NULL, error);
	if (!cred)
		goto error;

	smlTrace(TRACE_EXIT, "%s: %p", __func__, cred);
	return cred;
error:
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Creates a base64-formatted credential from a username/password pair.
 *
 * Both strings must be non-NULL and non-empty; they are copied.
 *
 * @param type     the authentication type to record
 * @param username the user name
 * @param password the password
 * @param error    location for a GError on failure
 * @return the new credential, or NULL on failure
 */
SmlCred*
smlCredNewAuth (SmlAuthType type,
                const gchar *username,
                const gchar *password,
                GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%d, %s, %p)", __func__, type, VA_STRING(username), error);
	CHECK_ERROR_REF

	SmlCred *cred = NULL;
	const gboolean missingUser = (username == NULL || username[0] == '\0');
	const gboolean missingPass = (password == NULL || password[0] == '\0');

	if (missingUser) {
		g_set_error(error, SML_ERROR, SML_ERROR_INTERNAL_MISCONFIGURATION,
			"If authentication should be used then the username must be set.");
		goto error;
	}
	if (missingPass) {
		g_set_error(error, SML_ERROR, SML_ERROR_INTERNAL_MISCONFIGURATION,
			"If authentication should be used then the password must be set.");
		goto error;
	}

	cred = smlTryMalloc0(sizeof(SmlCred), error);
	if (cred == NULL)
		goto error;

	cred->refCount = 1;
	cred->format = SML_FORMAT_TYPE_BASE64;
	cred->type = type;
	cred->username = g_strdup(username);
	cred->password = g_strdup(password);

	smlTrace(TRACE_EXIT, "%s", __func__);
	return cred;
error:
	if (cred != NULL)
		smlSafeFree((gpointer *)&cred);
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Creates a credential with a reference count of one.
 *
 * @param type     the authentication type
 * @param format   the payload format
 * @param data     the credential payload (copied, may be NULL)
 * @param username optional user name (copied, may be NULL)
 * @param error    location for a GError if the allocation fails
 * @return the new credential, or NULL on failure
 */
SmlCred*
smlCredNew (SmlAuthType type,
            SmlFormatType format,
            const gchar *data,
            const gchar *username,
            GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%s, %d, %d, %p)", __func__, VA_STRING(data), type, format, error);
	CHECK_ERROR_REF

	SmlCred *cred = smlTryMalloc0(sizeof(SmlCred), error);
	if (!cred)
		goto error;

	cred->type = type;
	cred->format = format;
	/* g_strdup() returns NULL for NULL, so optional values need no branch. */
	cred->data = g_strdup(data);
	cred->username = g_strdup(username);
	cred->refCount = 1;

	smlTrace(TRACE_EXIT, "%s: %p", __func__, cred);
	return cred;
error:
	/* Only reached when the allocation failed: cred is NULL and nothing
	 * has been allocated yet, so it must not be dereferenced here. */
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/** Atomically adds a reference to a credential (must not be NULL). */
void
smlCredRef (SmlCred *cred)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, cred);
	smlAssert(cred);

	gint *counter = &(cred->refCount);
	g_atomic_int_inc(counter);

	smlTrace(TRACE_EXIT, "%s: New refcount: %i", __func__, cred->refCount);
}

/**
 * Atomically drops a reference to a credential.
 *
 * On the last reference, the data, username and password strings are freed,
 * followed by the credential itself.
 */
void
smlCredUnref (SmlCred *cred)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, cred);
	smlAssert(cred);

	const gboolean lastRef = g_atomic_int_dec_and_test(&(cred->refCount));
	if (lastRef) {
		smlTrace(TRACE_INTERNAL, "%s: Refcount == 0!", __func__);

		if (cred->data != NULL)
			smlSafeCFree(&(cred->data));
		if (cred->username != NULL)
			smlSafeCFree(&(cred->username));
		if (cred->password != NULL)
			smlSafeCFree(&(cred->password));

		smlSafeFree((gpointer *)&cred);
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/** Returns the credential's user name; the credential keeps ownership. */
const gchar*
smlCredGetUsername (SmlCred *cred)
{
	smlAssert(cred);
	const gchar *name = cred->username;
	return name;
}

/** Returns the credential's password; the credential keeps ownership. */
const gchar*
smlCredGetPassword (SmlCred *cred)
{
	smlAssert(cred);
	const gchar *secret = cred->password;
	return secret;
}

/** Returns the credential's payload; the credential keeps ownership. */
const gchar*
smlCredGetData (SmlCred *cred)
{
	smlAssert(cred);
	const gchar *payload = cred->data;
	return payload;
}

/**
 * Replaces the credential's payload, taking ownership of @data.
 *
 * The credential frees its data when the last reference is dropped, so the
 * previously stored payload is released here instead of being leaked.
 *
 * @param cred the credential (must not be NULL)
 * @param data newly allocated payload; ownership passes to @cred
 */
void
smlCredSetData (SmlCred *cred,
                gchar *data)
{
	smlAssert(cred);
	/* Skip the free when the same buffer is set again. */
	if (cred->data && cred->data != data)
		smlSafeCFree(&(cred->data));
	cred->data = data;
}

/** Returns the credential's authentication type. */
SmlAuthType
smlCredGetType (SmlCred *cred)
{
	smlAssert(cred);
	const SmlAuthType kind = cred->type;
	return kind;
}

/**
 * Creates a challenge of the given type with a reference count of one.
 *
 * For SML_AUTH_TYPE_MD5, a random 26-character nonce is generated and
 * stored both in plain form and base64-encoded.
 *
 * @param type  the authentication type (must not be SML_AUTH_TYPE_UNKNOWN)
 * @param error location for a GError on failure
 * @return the new challenge, or NULL on failure
 */
SmlChal*
smlChalNew (SmlAuthType type,
            GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%u, %p)", __func__, type, error);
	CHECK_ERROR_REF
	SmlChal *chal = NULL;

	/* allocate memory */
	smlAssert(type != SML_AUTH_TYPE_UNKNOWN);
	chal = smlTryMalloc0(sizeof(SmlChal), error);
	if (!chal)
		goto error;
	chal->refCount = 1;
	chal->type = type;
	chal->format = SML_FORMAT_TYPE_BASE64;

	if (type == SML_AUTH_TYPE_MD5)
	{
		/* A nonce must be generated for this type.
		 *     minimum strength:  2^128
		 *     strength per byte: 2^6 - 2 > 2^5
		 *     needed bytes:      128 / 5 < 130 / 5 = 26
		 */
		chal->nonce_plain = smlRandStr(26, TRUE);
		chal->nonce_length = 26;
		chal->nonce_b64 = g_base64_encode(
						(const unsigned char *) chal->nonce_plain,
						chal->nonce_length);
		if (!chal->nonce_b64) {
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
				"The nonce of the challenge cannot be base64 encoded.");
			goto error;
		}
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return chal;
error:
	/* chal is NULL when the allocation itself failed. */
	if (chal) {
		if (chal->nonce_plain)
			smlSafeCFree(&(chal->nonce_plain));
		if (chal->nonce_b64)
			smlSafeCFree(&(chal->nonce_b64));
		smlSafeFree((gpointer *)&chal);
	}
	smlTrace(TRACE_EXIT_ERROR, "%s - %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Creates an MD5 challenge from a raw nonce.
 *
 * The nonce is copied in plain form and also stored base64-encoded.
 *
 * @param type   must be SML_AUTH_TYPE_MD5
 * @param nonce  the raw nonce bytes
 * @param length number of bytes in @nonce
 * @param error  location for a GError on failure
 * @return the new challenge, or NULL on failure
 */
SmlChal*
smlChalNewFromBinary(SmlAuthType type,
                     const gchar *nonce,
                     gsize length,
                     GError **error)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	CHECK_ERROR_REF
	SmlChal *chal = NULL;

	/* only syncml:auth-md5 needs a nonce */
	smlAssert(type == SML_AUTH_TYPE_MD5);

	/* allocate memory */
	chal = smlTryMalloc0(sizeof(SmlChal), error);
	if (!chal)
		goto error;
	chal->refCount = 1;

	/* copy nonce */
	chal->type = SML_AUTH_TYPE_MD5;
	chal->format = SML_FORMAT_TYPE_BASE64;
	chal->nonce_plain = g_strndup(nonce, length);
	chal->nonce_length = length;

	/* create base64 nonce */
	chal->nonce_b64 = g_base64_encode((const unsigned char *) nonce, length);
	if (!chal->nonce_b64) {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
			"The base64 encoding of the nonce failed.");
		goto error;
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return chal;
error:
	/* Release the partially built challenge instead of leaking it. */
	if (chal) {
		if (chal->nonce_plain)
			smlSafeCFree(&(chal->nonce_plain));
		if (chal->nonce_b64)
			smlSafeCFree(&(chal->nonce_b64));
		smlSafeFree((gpointer *)&chal);
	}
	smlTrace(TRACE_EXIT_ERROR, "%s - %s", __func__, (*error)->message);
	return NULL;
}

/**
 * Creates an MD5 challenge from a base64-encoded nonce.
 *
 * The encoded nonce is copied, and the decoded bytes are stored as the
 * plain nonce. Decoding must yield at least one byte.
 *
 * @param type  must be SML_AUTH_TYPE_MD5
 * @param nonce the base64-encoded nonce (must not be NULL)
 * @param error location for a GError on failure
 * @return the new challenge, or NULL on failure
 */
SmlChal*
smlChalNewFromBase64 (SmlAuthType type,
                      const gchar *nonce,
                      GError **error)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	CHECK_ERROR_REF
	SmlChal *chal = NULL;
	/* g_base64_decode() reports the decoded length through a gsize*. */
	gsize length = 0;

	/* only syncml:auth-md5 needs a nonce */
	smlAssert(type == SML_AUTH_TYPE_MD5);

	if (!nonce) {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
			"The base64 encoded nonce cannot be decoded.");
		goto error;
	}

	/* allocate memory */
	chal = smlTryMalloc0(sizeof(SmlChal), error);
	if (!chal)
		goto error;
	chal->refCount = 1;

	/* copy nonce */
	chal->type = SML_AUTH_TYPE_MD5;
	chal->format = SML_FORMAT_TYPE_BASE64;
	chal->nonce_b64 = g_strdup(nonce);

	/* create binary nonce */
	chal->nonce_plain = (char *) g_base64_decode(nonce, &length);
	if (!chal->nonce_plain || length < 1) {
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
			"The base64 encoded nonce cannot be decoded.");
		goto error;
	}
	chal->nonce_length = length;

	smlTrace(TRACE_EXIT, "%s", __func__);
	return chal;
error:
	/* Release the partially built challenge instead of leaking it. */
	if (chal) {
		if (chal->nonce_plain)
			smlSafeCFree(&(chal->nonce_plain));
		if (chal->nonce_b64)
			smlSafeCFree(&(chal->nonce_b64));
		smlSafeFree((gpointer *)&chal);
	}
	smlTrace(TRACE_EXIT_ERROR, "%s - %s", __func__, (*error)->message);
	return NULL;
}

/** Atomically adds a reference to a challenge (must not be NULL). */
void
smlChalRef (SmlChal *chal)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	smlAssert(chal);

	gint *counter = &(chal->refCount);
	g_atomic_int_inc(counter);

	smlTrace(TRACE_EXIT, "%s: New refcount: %i", __func__, chal->refCount);
}

/**
 * Atomically drops a reference to a challenge.
 *
 * On the last reference, the plain and base64 nonces are freed, followed by
 * the challenge itself.
 */
void
smlChalUnref (SmlChal *chal)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	smlAssert(chal);

	const gboolean lastRef = g_atomic_int_dec_and_test(&(chal->refCount));
	if (lastRef) {
		smlTrace(TRACE_INTERNAL, "%s: Refcount == 0!", __func__);

		if (chal->nonce_plain != NULL)
			smlSafeCFree(&(chal->nonce_plain));
		if (chal->nonce_b64 != NULL)
			smlSafeCFree(&(chal->nonce_b64));

		smlSafeFree((gpointer *)&chal);
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/**
 * Returns the challenge's nonce in its base64-encoded form only.
 * The challenge keeps ownership of the string.
 */
const gchar*
smlChalGetNonce (SmlChal *chal)
{
	smlAssert(chal);
	const gchar *encoded = chal->nonce_b64;
	return encoded;
}

/** Returns the challenge's authentication type. */
SmlAuthType
smlChalGetType (SmlChal *chal)
{
	smlAssert(chal);
	const SmlAuthType kind = chal->type;
	return kind;
}

