/*
 * libsyncml - A syncml protocol implementation
 * Copyright (C) 2005  Armin Bauer <armin.bauer@opensync.org>
 * Copyright (C) 2008  Michael Bell <michael.bell@opensync.org>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
 *
 */

#include "sml_auth_internals.h"

#include <libsyncml/sml_support.h>
#include "libsyncml/sml_error_internals.h"

#include <libsyncml/sml_session.h>
#include <libsyncml/sml_elements.h>
#include <libsyncml/sml_command.h>
#include <libsyncml/sml_md5.h>

#include <string.h>

static SmlStatus* _smlAuthHeaderReply(SmlSession *session, SmlErrorType code, SmlAuthType auth, GError **error);

/*
 * Status handler that smlAuthRegister() installs next to
 * _header_callback. It does not inspect the status; it only emits
 * the entry and exit trace records.
 */
void _status_callback(SmlSession *session, SmlStatus *status, void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, session, status, userdata);
	/* Nothing to do for a status in this module. */
	smlTrace(TRACE_EXIT, "%s", __func__);
}

/*
 * Header handler that smlAuthRegister() installs on the manager.
 *
 * It decides the authentication state for the received header, sends
 * the matching status reply and, for server sessions, marks the session
 * as established once no further authentication is pending.
 *
 * session  - the session the header belongs to (must not be NULL)
 * header   - the received header (only traced here)
 * cred     - the credential sent by the remote peer, or NULL
 * userdata - the SmlAuthenticator passed to smlAuthRegister()
 *            (must not be NULL)
 *
 * State decision:
 *  - client session:           SML_NO_ERROR (clients do not authenticate peers)
 *  - no cred, auth disabled:   SML_NO_ERROR
 *  - no cred, auth enabled:    SML_ERROR_AUTH_REQUIRED unless the state is
 *                              already SML_AUTH_ACCEPTED
 *  - cred, auth disabled:      SML_AUTH_ACCEPTED
 *  - cred, auth enabled:       result of the verify callback; without a
 *                              callback the cred is rejected
 *
 * Errors are reported through SML_SESSION_EVENT_ERROR events.
 */
void _header_callback(SmlSession *session, SmlHeader *header, SmlCred *cred, void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p, %p)", __func__, session, header, cred, userdata);
	smlAssert(session);
	smlAssert(userdata);
	SmlStatus *reply = NULL;
	SmlAuthenticator *auth = userdata;
	GError *error = NULL;

	if (smlSessionGetType(session) == SML_SESSION_TYPE_CLIENT) {
		/* If this is an OMA DS client then there will be no
		 * authentication. Only OMA DS servers can request an
		 * authentication from the remote peer.
		 */
		g_warning("This is an OMA DS client. An OMA DS client should not use this authentication callback.");
		smlTrace(TRACE_INTERNAL, "%s: This is an OMA DS client and so auth is not supported.", __func__);
		auth->state = SML_NO_ERROR;
		if (auth->enabled) {
			/* The format string needs __func__ for its %s. */
			smlTrace(TRACE_ERROR,
				"%s: authentication is enabled but this is an OMA DS client.",
				__func__);
		}
	}
	else if (!cred)
	{
		if (!auth->enabled)
		{
			smlTrace(TRACE_INTERNAL, "%s: Auth is disabled and no cred given", __func__);
			auth->state = SML_NO_ERROR;
		} else {
			/* auth->enabled */
			if (auth->state != SML_AUTH_ACCEPTED)
			{
				/* Ask the remote peer to authenticate.
				 * This is not an error because it is really dangerous
				 * for the remote peer to send its password without
				 * being asked for it. Please note that syncml:auth-basic
				 * is a clear text password.
				 */
				smlTrace(TRACE_INTERNAL, "%s: Auth is required", __func__);
				auth->state = SML_ERROR_AUTH_REQUIRED;
				smlSessionSetAuthenticate(session, TRUE);
			}
			else
			{
				/* auth->state == SML_AUTH_ACCEPTED */
				smlTrace(TRACE_INTERNAL, "%s: Auth is already accepted.", __func__);
				auth->state = SML_AUTH_ACCEPTED;
			}
		}
	} else {
		/* cred available */
		smlTrace(TRACE_INTERNAL, "%s: Cred is \"%s\"", __func__, VA_STRING(smlCredGetData(cred)));
		if (!auth->enabled)
		{
			smlTrace(TRACE_INTERNAL, "%s: Cred received but unwanted", __func__);
			auth->state = SML_AUTH_ACCEPTED;
		} else {
			/* auth->enabled */
			if (auth->verifyCallback)
			{
				/* The callback needs the following stuff:
				 * 	Chal
				 * 	Cred
				 * 	LocName(username)
				 */
				if (auth->verifyCallback(smlSessionGetChal(session), cred,
				                         sml_location_get_name(smlSessionGetSource(session)),
				                         auth->verifyCallbackUserdata, &error))
				{
					auth->state = SML_AUTH_ACCEPTED;
				} else {
					if (error) {
						/* Wrap the callback's message into an
						 * SML_ERROR_AUTH_REJECTED error.
						 */
						smlTrace(TRACE_ERROR, "%s: %s", __func__, error->message);
						GError *cb_error = error;
						error = NULL;
						g_set_error(&error, SML_ERROR, SML_ERROR_AUTH_REJECTED,
						            "Auth rejected for username %s. %s",
						            sml_location_get_name(smlSessionGetSource(session)),
						            cb_error->message);
						g_error_free(cb_error);
					} else {
						/* No callback message is available, so the
						 * format must carry exactly one %s.
						 */
						g_set_error(&error, SML_ERROR, SML_ERROR_AUTH_REJECTED,
						            "Auth rejected for username %s.",
						            sml_location_get_name(smlSessionGetSource(session)));
					}
					smlSessionDispatchEvent(session, SML_SESSION_EVENT_ERROR, NULL, NULL, NULL, error);
					g_error_free(error);
					error = NULL;
					auth->state = SML_ERROR_AUTH_REJECTED;
				}
			} else {
				smlTrace(TRACE_INTERNAL, "%s: No verify callback set", __func__);
				auth->state = SML_ERROR_AUTH_REJECTED;
			}
		}
	}

	if (auth->state == SML_ERROR_AUTH_REJECTED) {
		smlTrace(TRACE_INTERNAL, "%s: Ending session due to wrong / missing creds", __func__);
		smlSessionSetEnd(session, TRUE);
	}

	reply = _smlAuthHeaderReply(session, auth->state, auth->type, &error);
	if (!reply)
		goto error;

	if (!smlSessionSendReply(session, reply, &error)) {
		smlStatusUnref(reply);
		goto error;
	}

	smlStatusUnref(reply);

	/* Only a server session that is neither ending nor waiting for
	 * authentication is announced as established here.
	 */
	if (!smlSessionGetEstablished(session) &&
	    !smlSessionGetEnd(session) &&
	    !smlSessionGetAuthenticate(session) &&
	    smlSessionGetType(session) == SML_SESSION_TYPE_SERVER)
	{
		smlSessionSetEstablished(session, TRUE);
		smlSessionDispatchEvent(
			session, SML_SESSION_EVENT_ESTABLISHED,
			NULL, NULL, NULL, NULL);
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return;
error:
	smlSessionDispatchEvent(session, SML_SESSION_EVENT_ERROR, NULL, NULL, NULL, error);
	/* Do not dereference or free a NULL error if a callee failed
	 * without filling it in.
	 */
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__,
	         error ? error->message : "unknown error");
	if (error)
		g_error_free(error);
	return;
}

/*
 * Build the credential string for the given authentication type.
 *
 * type      - SML_AUTH_TYPE_BASIC or SML_AUTH_TYPE_MD5
 * username  - user name, joined with the password by ":"
 * password  - password
 * b64_nonce - nonce string appended for syncml:auth-md5
 *             (not used for syncml:auth-basic)
 * error     - receives the error on failure
 *
 * Returns a newly allocated string, or NULL on failure. The string
 * comes from g_base64_encode(), so the caller releases it.
 *
 *   basic: base64(username ":" password)
 *   md5:   base64(md5(base64(md5(username ":" password)) ":" nonce))
 */
gchar*
smlAuthGetCredString (SmlAuthType type,
                      const gchar *username,
                      const gchar *password,
                      const gchar *b64_nonce,
                      GError **error)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	CHECK_ERROR_REF
	char *cred = NULL;

	switch (type) {
		case SML_AUTH_TYPE_BASIC: {
			smlTrace(TRACE_INTERNAL, "%s - SML_AUTH_TYPE_BASIC", __func__);
			char *plain = g_strjoin(":", username, password, NULL);
			cred = g_base64_encode((unsigned char *) plain, strlen(plain));
			smlSafeCFree(&plain);
			if (!cred) {
				g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
					"The syncml:auth-basic credential cannot be base64 encoded.");
				goto error;
			}
			break;
		}
		case SML_AUTH_TYPE_MD5: {
			smlTrace(TRACE_INTERNAL, "%s - SML_AUTH_TYPE_MD5", __func__);
			/* How does syncml:auth-md5 work?
			 *
			 * base64(
			 *        md5(
			 *            base64(
			 *                   md5(
			 *                       username + ":" + password
			 *                      )
			 *                  ) +
			 *            ":" + nonce
			 *           )
			 *       )
			 */
			unsigned char digest[16];

			/* Inner part: base64(md5(username:password)) */
			char *auth = g_strjoin (":", username, password, NULL);
			smlMD5GetDigest (auth, strlen(auth), digest);
			smlSafeCFree(&auth);
			cred = g_base64_encode(digest, sizeof digest);
			if (!cred) {
				g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
					"The username:password part of the syncml:auth-md5 "\
					"credential cannot be base64 encoded.");
				goto error;
			}

			/* Outer part: base64(md5(inner:nonce)) */
			auth = g_strjoin (":", cred, b64_nonce, NULL);
			smlSafeCFree(&cred);
			smlMD5GetDigest (auth, strlen(auth), digest);
			smlSafeCFree(&auth);
			cred = g_base64_encode(digest, sizeof digest);
			if (!cred) {
				g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
					"The complete syncml:auth-md5 credential cannot be base64 encoded.");
				goto error;
			}
			break;
		}
		default:
			smlTrace(TRACE_ERROR, "%s - unknown authentication type", __func__);
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unknown auth format");
			goto error;
	}

	smlTrace(TRACE_EXIT, "%s", __func__);
	return cred;
error:
	/* The format string needs __func__ for its %s. */
	smlTrace(TRACE_EXIT_ERROR, "%s - cannot create credential string", __func__);
	/* Only fill in a fallback error when the caller supplied an
	 * error location and nothing more specific was stored above.
	 */
	if (error != NULL && *error == NULL)
		g_set_error(error, SML_ERROR, SML_ERROR_GENERIC,
		            "Cannot create credential string for user %s.",
		            username);
	return NULL;
}

/*
 * Verify a received credential against the expected user name and
 * password.
 *
 * chal     - the challenge sent to the peer earlier, or NULL
 * cred     - the credential received from the peer
 * username - expected user name
 * password - expected password
 * error    - receives the error on rejection or failure
 *
 * Returns TRUE if the credential data equals the string computed by
 * smlAuthGetCredString(), FALSE otherwise.
 */
gboolean
smlAuthVerify (SmlChal *chal,
               SmlCred *cred,
               const gchar *username,
               const gchar *password,
               GError **error)
{
	smlTrace(TRACE_ENTRY, "%s", __func__);
	CHECK_ERROR_REF
	char *wanted = NULL;

	/* If no Chal was sent to the client but the client offers a Cred
	 * then we accept it. Theoretically this is not 100% perfect if
	 * syncml:auth-md5 is used but if the client offers it then we
	 * accept it because some implementations like UIQ 3 are broken
	 * and implemented the protocol in a wrong way if we answer with
	 * error 407 and a fresh nonce.
	 *
	 * SECURITY NOTE: This part of the code makes libsyncml vulnerable
	 * SECURITY NOTE: against replay attacks.
	 *
	 */
	if (chal && smlChalGetType(chal) != smlCredGetType(cred))
	{
		if (smlChalGetType(chal) == SML_AUTH_TYPE_BASIC &&
		    smlCredGetType(cred) == SML_AUTH_TYPE_MD5)
		{
			/* This is an upgrade to more security.
			 * So it is acceptable.
			 */
			smlTrace(TRACE_INTERNAL, "%s - replace syncml:auth-basic by syncml:auth-md5", __func__);
		} else {
			/* This is a security event. */
			g_set_error(error, SML_ERROR, SML_ERROR_AUTH_REJECTED,
				"The type of the authentication was changed to a lower security level.");
			goto error;
		}
	}
	smlTrace(TRACE_INTERNAL, "%s - authentication security policy ok", __func__);

	switch (smlCredGetType(cred)) {
		case SML_AUTH_TYPE_BASIC:
			smlTrace(TRACE_INTERNAL, "%s - SML_AUTH_TYPE_BASIC", __func__);
			wanted = smlAuthGetCredString(SML_AUTH_TYPE_BASIC, username, password, NULL, error);
			break;
		case SML_AUTH_TYPE_MD5:
			smlTrace(TRACE_INTERNAL, "%s - SML_AUTH_TYPE_MD5", __func__);
			/* Without a challenge an empty nonce is used. */
			if (chal)
				wanted = smlAuthGetCredString(
						SML_AUTH_TYPE_MD5,
						username, password,
						smlChalGetNonce(chal), error);
			else
				wanted = smlAuthGetCredString(
						SML_AUTH_TYPE_MD5,
						username, password,
						"", error);
			break;
		default:
			smlTrace(TRACE_ERROR, "%s - unknown authentication type", __func__);
			g_set_error(error, SML_ERROR, SML_ERROR_GENERIC, "Unknown auth format");
			goto error;
	}

	/* smlAuthGetCredString() returns NULL on failure; comparing
	 * against NULL would crash in strcmp().
	 */
	if (!wanted) {
		smlTrace(TRACE_INTERNAL, "%s - credential string cannot be calculated", __func__);
		goto error;
	}
	smlTrace(TRACE_INTERNAL, "%s - credential string calculated", __func__);

	/* compare the authentication string
	 *
	 * NOTE(review): strcmp() is not a constant-time comparison -
	 * confirm whether timing differences matter for this code path.
	 */
	const char *received = smlCredGetData(cred);
	if (!received || strcmp(wanted, received))
	{
		smlTrace(TRACE_INTERNAL, "%s - credentials mismatch", __func__);
		smlSafeCFree(&wanted);
		goto error;
	}
	smlSafeCFree(&wanted);

	smlTrace(TRACE_EXIT, "%s", __func__);
	return TRUE;
error:
	/* The format string needs __func__ for its %s. */
	smlTrace(TRACE_EXIT_ERROR, "%s - auth rejected", __func__);
	/* Only fill in a fallback error when the caller supplied an
	 * error location and nothing more specific was stored above.
	 */
	if (error != NULL && *error == NULL)
		g_set_error(error, SML_ERROR, SML_ERROR_AUTH_REJECTED,
		            "Authentication rejected for username %s.",
		            username);
	return FALSE;
}

/*
 * Allocate a new authenticator with smlTryMalloc0().
 *
 * The new object starts with authentication enabled and the state
 * SML_ERROR_AUTH_REQUIRED.
 *
 * error - receives the allocation error
 *
 * Returns the new authenticator, or NULL if the allocation failed.
 */
SmlAuthenticator*
smlAuthNew (GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, error);
	CHECK_ERROR_REF

	SmlAuthenticator *authenticator = smlTryMalloc0(sizeof(SmlAuthenticator), error);
	if (authenticator == NULL) {
		smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
		return NULL;
	}

	authenticator->state = SML_ERROR_AUTH_REQUIRED;
	authenticator->enabled = TRUE;

	smlTrace(TRACE_EXIT, "%s: %p", __func__, authenticator);
	return authenticator;
}


/*
 * Release an authenticator.
 *
 * auth - the authenticator to free (must not be NULL)
 */
void
smlAuthFree (SmlAuthenticator *auth)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, auth);
	smlAssert(auth);

	/* smlSafeFree() takes the address of the pointer to release. */
	gpointer *slot = (gpointer *)&auth;
	smlSafeFree(slot);

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/*
 * Install the authentication handlers on a manager.
 *
 * _header_callback and _status_callback are registered as the header
 * handler pair, with the authenticator as their userdata.
 *
 * auth    - the authenticator (must not be NULL)
 * manager - the manager to register on (must not be NULL)
 * error   - not written by this function
 *
 * Always returns TRUE.
 */
gboolean
smlAuthRegister (SmlAuthenticator *auth,
                 SmlManager *manager,
                 GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, auth, manager, error);
	CHECK_ERROR_REF
	smlAssert(auth);
	smlAssert(manager);

	smlManagerRegisterHeaderHandler(manager,
	                                _header_callback,
	                                _status_callback,
	                                auth);

	smlTrace(TRACE_EXIT, "%s", __func__);
	return TRUE;
}

/*
 * Overwrite the authentication state of an authenticator.
 *
 * auth - the authenticator (must not be NULL)
 * type - the new state value
 */
void
smlAuthSetState (SmlAuthenticator *auth,
                 SmlErrorType type)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i)", __func__, auth, type);
	smlAssert(auth);

	/* The header callback reads this state when it builds its reply. */
	auth->state = type;

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/*
 * Build the status that answers a received header.
 *
 * session - the session the header was received on
 * code    - the status code to report
 * auth    - the authentication type used for a new challenge
 * error   - receives the error on failure
 *
 * For SML_ERROR_AUTH_REJECTED and SML_ERROR_AUTH_REQUIRED a new
 * challenge is created and attached to both the status and the session.
 *
 * Returns the new status, or NULL on failure. The caller in this file
 * unrefs the returned status.
 */
static SmlStatus*
_smlAuthHeaderReply (SmlSession *session,
                     SmlErrorType code,
                     SmlAuthType auth,
                     GError **error)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i, %i, %p)", __func__, session, code, auth, error);
	CHECK_ERROR_REF

	/* The session describes source and target from the viewpoint of
	 * this machine. SourceRef and TargetRef of the status MUST use the
	 * viewpoint of the remote peer, so the two are swapped below.
	 */
	smlTrace(TRACE_INTERNAL, "%s: SourceRef: %s --> TargetRef: %s",
		__func__,
		VA_STRING(sml_location_get_uri(smlSessionGetTarget(session))),
		VA_STRING(sml_location_get_uri(smlSessionGetSource(session))));

	SmlStatus *reply = smlStatusNew(
		code,
		0,
		smlSessionGetLastRecvMsgID(session),
		smlSessionGetTarget(session),
		smlSessionGetSource(session),
		SML_COMMAND_TYPE_HEADER,
		error);
	if (reply == NULL)
		goto error;

	if (code == SML_ERROR_AUTH_REJECTED || code == SML_ERROR_AUTH_REQUIRED) {
		SmlChal *challenge = smlChalNew(auth, error);
		if (challenge == NULL)
			goto error;

		/* Hand the challenge to the status and the session, then
		 * drop the local reference.
		 */
		smlStatusSetChal(reply, challenge);
		smlSessionSetChal(session, challenge);
		smlChalUnref(challenge);
		challenge = NULL;
	}

	smlTrace(TRACE_EXIT, "%s: %p", __func__, reply);
	return reply;
error:
	if (reply != NULL)
		smlStatusUnref(reply);
	smlTrace(TRACE_EXIT_ERROR, "%s: %s", __func__, (*error)->message);
	return NULL;
}

/*
 * Set the callback that verifies received credentials.
 *
 * auth     - the authenticator (must not be NULL)
 * callback - the verification function
 * userdata - the pointer handed to the callback on each call
 */
void
smlAuthSetVerifyCallback (SmlAuthenticator *auth,
                          SmlAuthVerifyCb callback,
                          void *userdata)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %p, %p)", __func__, auth, callback, userdata);
	smlAssert(auth);

	auth->verifyCallback = callback;
	auth->verifyCallbackUserdata = userdata;

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/*
 * Switch authentication on or off for an authenticator.
 *
 * auth    - the authenticator (must not be NULL)
 * enabled - TRUE to enable authentication, FALSE to disable it
 */
void
smlAuthSetEnable (SmlAuthenticator *auth,
                  gboolean enabled)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i)", __func__, auth, enabled);
	smlAssert(auth);

	auth->enabled = enabled;

	smlTrace(TRACE_EXIT, "%s", __func__);
}

/*
 * Report whether authentication is enabled.
 *
 * auth - the authenticator (must not be NULL)
 *
 * Returns the authenticator's enabled flag.
 */
gboolean
smlAuthIsEnabled (SmlAuthenticator *auth)
{
	smlTrace(TRACE_ENTRY, "%s(%p)", __func__, auth);
	smlAssert(auth);

	gboolean enabled = auth->enabled;

	smlTrace(TRACE_EXIT, "%s - %u", __func__, enabled);
	return enabled;
}

/*
 * Set the authentication type of an authenticator.
 *
 * auth - the authenticator (must not be NULL)
 * type - the authentication type; SML_AUTH_TYPE_UNKNOWN is rejected
 *        by an assertion
 */
void
smlAuthSetType (SmlAuthenticator *auth,
                SmlAuthType type)
{
	smlTrace(TRACE_ENTRY, "%s(%p, %i)", __func__, auth, type);
	smlAssert(auth);
	smlAssert(type != SML_AUTH_TYPE_UNKNOWN);

	/* The header callback passes this type on when it builds its reply. */
	auth->type = type;

	smlTrace(TRACE_EXIT, "%s", __func__);
}

