/* BEGIN LICENSE
  * Copyright © Blue Mind SAS, 2012-2023
  *
  * This file is part of BlueMind. BlueMind is a messaging and collaborative
  * solution.
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of either the GNU Affero General Public License as
  * published by the Free Software Foundation (version 3 of the License).
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  *
  * See LICENSE.txt
  * END LICENSE
  */
package net.bluemind.webmodule.authenticationfilter;

import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Base64.Decoder;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.auth0.jwt.JWT;
import com.auth0.jwt.exceptions.JWTDecodeException;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.google.common.base.Strings;

import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.RequestOptions;
import io.vertx.core.json.DecodeException;
import io.vertx.core.json.JsonObject;
import net.bluemind.core.api.auth.AuthDomainProperties;
import net.bluemind.core.api.auth.AuthTypes;
import net.bluemind.core.api.fault.ErrorCode;
import net.bluemind.core.api.fault.ServerFault;
import net.bluemind.keycloak.api.IKeycloakUids;
import net.bluemind.keycloak.utils.endpoint.KeycloakEndpoints;
import net.bluemind.keydb.sessiondata.CodeVerifierCache;
import net.bluemind.keydb.sessiondata.SessionData;
import net.bluemind.keydb.sessiondata.SessionDataStore;
import net.bluemind.webmodule.authenticationfilter.internal.AuthenticationCookie;
import net.bluemind.webmodule.authenticationfilter.internal.ExternalCreds;

/**
 * Vert.x handler for the OpenID Connect authorization-code callback.
 *
 * <p>
 * It checks the {@code code} and {@code state} request parameters and exchanges
 * the code against the domain token endpoint. The exchange includes the PKCE
 * code verifier, looked up through the {@code state}. It then either re-uses
 * the BlueMind session already bound to the returned JWT {@code session_state}
 * or creates a new one. Finally it redirects the browser to the {@code path}
 * carried by the {@code state} parameter, or to {@code /} when absent.
 */
public class OpenIdHandler extends AbstractAuthHandler implements Handler<HttpServerRequest> {
	private static final Logger logger = LoggerFactory.getLogger(OpenIdHandler.class);

	/** Token endpoint response field used to find an already existing BlueMind session. */
	public static final String JWT_SESSION_STATE = "session_state";

	private static final Decoder b64UrlDecoder = Base64.getUrlDecoder();

	/**
	 * Finishes a login once its {@link SessionData} is known.
	 *
	 * <p>
	 * It decorates the pending HTTP response with the authentication cookies and
	 * stores the OpenID token data in the {@link SessionDataStore}. Static because
	 * it needs no state of the enclosing handler.
	 */
	private static class SessionConsumer implements Consumer<SessionData> {
		private final HttpServerRequest request;
		private final String domainUid;
		private final String realmId;
		private final boolean internalAuth;
		private final JsonObject jwtToken;
		private final String openIdClientSecret;

		/**
		 * @param request            the callback request whose response gets the cookies
		 * @param domainUid          BlueMind domain uid taken from the {@code state}
		 * @param openIdClientSecret client secret configured for the domain
		 * @param internalAuth       whether the domain uses the BlueMind Keycloak
		 * @param jwtToken           parsed token endpoint response
		 */
		public SessionConsumer(HttpServerRequest request, String domainUid, String openIdClientSecret,
				boolean internalAuth, JsonObject jwtToken) {
			this.request = request;
			this.domainUid = domainUid;
			this.realmId = IKeycloakUids.realmId(domainUid);
			this.openIdClientSecret = openIdClientSecret;
			this.internalAuth = internalAuth;
			this.jwtToken = jwtToken;
		}

		/**
		 * Sets the response cookies, then stores the session with its OpenID data.
		 *
		 * <p>
		 * The last argument passed to {@code setOpenId} is the session creation stamp
		 * plus {@link SessionDataStore#SESSIONID_REFRESH_PERIOD}.
		 */
		@Override
		public void accept(SessionData sessionData) {
			decorateResponse(sessionData);

			SessionDataStore.get().put(sessionData.setOpenId(jwtToken, realmId, openIdClientSecret, internalAuth,
					sessionData.createStamp + SessionDataStore.SESSIONID_REFRESH_PERIOD));

			if (logger.isInfoEnabled()) {
				logger.info("[{}] Session {} for user {} created, JWT SID: {}", request.path(), sessionData.authKey,
						sessionData.loginAtDomain, jwtToken.getValue(JWT_SESSION_STATE));
			}
		}

		/**
		 * Adds the authentication cookies to the pending response.
		 *
		 * <ul>
		 * <li>The OpenID session cookie holds {@code sid}, {@code domain_uid} and
		 * {@code user_uid}.</li>
		 * <li>The privacy cookie is {@code true} only when the access token's
		 * {@code bm_pubpriv} claim equals {@code "private"}.</li>
		 * </ul>
		 */
		public void decorateResponse(SessionData sessionData) {
			MultiMap headers = request.response().headers();

			JsonObject cookie = new JsonObject();
			cookie.put("sid", sessionData.authKey);
			cookie.put("domain_uid", domainUid);
			cookie.put("user_uid", sessionData.userUid);
			AuthenticationCookie.add(headers, AuthenticationCookie.OPENID_SESSION, cookie.encode());

			Claim pubpriv = JWT.decode(jwtToken.getString("access_token")).getClaim("bm_pubpriv");
			boolean privateComputer = "private".equals(pubpriv.asString());
			AuthenticationCookie.add(headers, AuthenticationCookie.BMPRIVACY, Boolean.toString(privateComputer));
		}
	}

	/**
	 * Entry point for the OpenID callback request.
	 *
	 * <p>
	 * Redirects to {@code /} when {@code code} or {@code state} is unusable. When a
	 * BlueMind session already exists, it redirects straight to the target path.
	 * Otherwise it POSTs the authorization-code grant to the domain token endpoint.
	 * The response is handled by
	 * {@link #tokenEndpointResponseHandler(HttpServerRequest, List, String, boolean, String, JsonObject, AsyncResult)}.
	 */
	@Override
	public void handle(HttpServerRequest request) {
		List<String> forwardedFor = new ArrayList<>(request.headers().getAll("X-Forwarded-For"));
		forwardedFor.add(request.remoteAddress().host());

		JsonObject stateRequestParameter = getRequestStateParam(request, forwardedFor);
		if (Strings.isNullOrEmpty(request.params().get("code")) || stateRequestParameter == null) {
			logger.debug("[{}][{}] null or empty 'code' or invalid 'state' request parameter", forwardedFor,
					request.path());
			sendFoundRedirect(request, "/");
			return;
		}

		if (sessionExists(request)) {
			sendFoundRedirect(request, getRedirectTo(request, stateRequestParameter.getString("path")));
			return;
		}

		String codeVerifierKey = stateRequestParameter.getString("codeVerifierKey");
		String codeVerifier = CodeVerifierCache.getAndRemove(codeVerifierKey);
		if (Strings.isNullOrEmpty(codeVerifier)) {
			error(request, new Throwable("OpenId codeVerifier key '" + codeVerifierKey
					+ "' not found in cache (expired ?), ignore request from [" + String.join(",", forwardedFor) + "]"));
			return;
		}

		String domainUid = stateRequestParameter.getString("domain_uid");

		try {
			// Everything that may throw (unknown domain, missing settings, null values to
			// URL-encode, missing authority) runs here, so it is reported through error().
			String realmId = IKeycloakUids.realmId(domainUid);
			Map<String, String> domainSettings = DomainsSettings.forDomain(domainUid);
			boolean internalAuth = AuthTypes.get(domainSettings.get(AuthDomainProperties.AUTH_TYPE.name()))
					.useBlueMindKeycloak();
			String openIdClientSecret = domainSettings.get(AuthDomainProperties.OPENID_CLIENT_SECRET.name());
			byte[] postData = buildTokenRequestBody(request, realmId, internalAuth, domainSettings,
					openIdClientSecret, codeVerifier);

			httpClient.request(new RequestOptions().setMethod(HttpMethod.POST)
					.setAbsoluteURI(tokenEndpoint(realmId, internalAuth, domainSettings)), reqHandler -> {
						if (reqHandler.failed()) {
							error(request, reqHandler.cause());
							return;
						}

						HttpClientRequest r = reqHandler.result();
						r.response(respHandler -> tokenEndpointResponseHandler(request, forwardedFor, domainUid,
								internalAuth, openIdClientSecret, stateRequestParameter, respHandler));

						MultiMap headers = r.headers();
						headers.add(HttpHeaders.ACCEPT_CHARSET, StandardCharsets.UTF_8.name());
						headers.add(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded");
						headers.add(HttpHeaders.CONTENT_LENGTH, Integer.toString(postData.length));
						r.write(Buffer.buffer(postData));
						r.end();
					});
		} catch (Exception e) {
			error(request, e);
			// Every other failure path returns right after error(), so error() presumably
			// completes the response. Only end it here if it is still open, to avoid writing
			// an already written response twice.
			if (!request.response().ended()) {
				request.response().end();
			}
		}
	}

	/**
	 * Builds the {@code application/x-www-form-urlencoded} authorization-code grant
	 * sent to the token endpoint, encoded as UTF-8.
	 *
	 * <p>
	 * With the BlueMind Keycloak the client id is derived from the realm. Otherwise
	 * it is the domain's configured OpenID client id. The redirect URI is
	 * {@code https://<request host>/auth/openid}.
	 */
	private byte[] buildTokenRequestBody(HttpServerRequest request, String realmId, boolean internalAuth,
			Map<String, String> domainSettings, String openIdClientSecret, String codeVerifier) {
		String clientId = internalAuth ? IKeycloakUids.clientId(realmId)
				: encode(domainSettings.get(AuthDomainProperties.OPENID_CLIENT_ID.name()));

		String params = "grant_type=authorization_code" //
				+ "&client_id=" + clientId //
				+ "&client_secret=" + encode(openIdClientSecret) //
				+ "&code=" + encode(request.params().get("code")) //
				+ "&code_verifier=" + encode(codeVerifier) //
				+ "&redirect_uri=" + encode("https://" + request.authority().host() + "/auth/openid") //
				+ "&scope=openid";
		return params.getBytes(StandardCharsets.UTF_8);
	}

	/** Answers {@code request} with a {@code 302} redirect to {@code location}. */
	private void sendFoundRedirect(HttpServerRequest request, String location) {
		request.response().headers().add(HttpHeaders.LOCATION, location);
		request.response().setStatusCode(302);
		request.response().end();
	}

	/**
	 * Decodes the {@code state} request parameter.
	 *
	 * @return the base64url-decoded (UTF-8) JSON object, or {@code null} when the
	 *         parameter is absent, blank, not valid base64url, or not valid JSON
	 */
	private JsonObject getRequestStateParam(HttpServerRequest request, List<String> forwardedFor) {
		String stateAsString = request.params().get("state");
		if (stateAsString == null || stateAsString.isBlank()) {
			logger.debug("[{}][{}] 'state' parameter is null or blank", forwardedFor, request.path());
			return null;
		}

		try {
			return new JsonObject(new String(b64UrlDecoder.decode(stateAsString), StandardCharsets.UTF_8));
		} catch (DecodeException | IllegalArgumentException | IndexOutOfBoundsException e) {
			logger.debug("[{}][{}] invalid 'state' parameter", forwardedFor, request.path(), e);
			return null;
		}
	}

	/**
	 * Handles the token endpoint answer.
	 *
	 * <p>
	 * A non-200 status, an unreadable body or a non-JSON body is reported through
	 * {@code error}. If a BlueMind session already exists for the returned
	 * {@code session_state}, that session's cookies are set and the browser is
	 * redirected. Otherwise the access token is validated and a new session is
	 * created.
	 */
	private void tokenEndpointResponseHandler(HttpServerRequest request, List<String> forwardedFor, String domainUid,
			boolean internalAuth, String openIdClientSecret, JsonObject stateRequestParameter,
			AsyncResult<HttpClientResponse> response) {
		if (response.failed()) {
			error(request, response.cause());
			return;
		}

		HttpClientResponse resp = response.result();
		if (resp.statusCode() != 200) {
			error(request, new ServerFault(
					"Invalid request to domain " + domainUid + " token endpoint, HTTP code: " + resp.statusCode(),
					ErrorCode.INVALID_PARAMETER));
			return;
		}

		resp.body(body -> {
			// A failed body read has a null result; report it instead of dereferencing it.
			if (body.failed()) {
				error(request, body.cause());
				return;
			}

			String bodyAsString = new String(body.result().getBytes(), StandardCharsets.UTF_8);
			JsonObject jwtToken;
			try {
				jwtToken = new JsonObject(bodyAsString);
			} catch (DecodeException | NullPointerException e) {
				error(request,
						new ServerFault(
								"Invalid response from domain " + domainUid + " token endpoint: "
										+ bodyAsString.substring(0, Math.min(30, bodyAsString.length())),
								ErrorCode.INVALID_PARAMETER));
				return;
			}

			String redirectTo = getRedirectTo(request, stateRequestParameter.getString("path"));
			SessionConsumer sessionConsumer = new SessionConsumer(request, domainUid, openIdClientSecret,
					internalAuth, jwtToken);

			SessionData sessionData = SessionDataStore.get()
					.getFromSessionState((String) jwtToken.getValue(JWT_SESSION_STATE));
			if (sessionData != null) {
				logger.info(
						"BlueMind session {} already exists for JWT session_state {}, don't create a new one, redirect to {}",
						sessionData.authKey, jwtToken.getValue(JWT_SESSION_STATE), redirectTo);
				sessionConsumer.decorateResponse(sessionData);
				sendFoundRedirect(request, redirectTo);
			} else {
				validateToken(request, forwardedFor, domainUid, internalAuth, redirectTo, jwtToken, sessionConsumer);
			}
		});
	}

	/**
	 * @return {@code redirectTo}, or {@code "/"} when it is {@code null}
	 */
	private String getRedirectTo(HttpServerRequest request, String redirectTo) {
		if (redirectTo == null) {
			redirectTo = "/";
		}

		if (logger.isDebugEnabled()) {
			logger.debug("[{}] Redirect to {}", request.path(), redirectTo);
		}
		return redirectTo;
	}

	/**
	 * @return whether the {@code BMSessionId} request header references a session
	 *         present in the {@link SessionDataStore}
	 */
	private boolean sessionExists(HttpServerRequest request) {
		String sessionId = request.getHeader("BMSessionId");
		if (sessionId == null) {
			return false;
		}

		return SessionDataStore.get().getIfPresent(sessionId) != null;
	}

	/**
	 * @return the BlueMind Keycloak token endpoint of {@code realmId} when
	 *         {@code internalAuth}, otherwise the domain's configured OpenID token
	 *         endpoint (may be {@code null} if not configured)
	 */
	private String tokenEndpoint(String realmId, boolean internalAuth, Map<String, String> domainSettings) {
		return internalAuth ? KeycloakEndpoints.tokenEndpoint(realmId)
				: domainSettings.get(AuthDomainProperties.OPENID_TOKEN_ENDPOINT.name());
	}

	/** URL-encodes {@code s} as UTF-8; throws {@link NullPointerException} on {@code null}. */
	private String encode(String s) {
		return URLEncoder.encode(s, StandardCharsets.UTF_8);
	}

	/**
	 * Decodes the access token of the token endpoint response and starts the
	 * session creation from its claims.
	 *
	 * <p>
	 * A missing or undecodable access token is reported through {@code error}
	 * instead of being thrown into the HTTP client callback, so the browser always
	 * gets an answer. The response itself is never logged because it carries
	 * tokens; only its field names are.
	 */
	private void validateToken(HttpServerRequest request, List<String> forwardedFor, String domainUid,
			boolean internalAuth, String redirectTo, JsonObject jwtToken,
			Consumer<SessionData> handlerSessionConsumer) {
		String rawAccessToken = jwtToken.getString("access_token");
		if (Strings.isNullOrEmpty(rawAccessToken)) {
			logger.error("Unexpected token endpoint response from domain {}: no access_token, fields: {}", domainUid,
					jwtToken.fieldNames());
			error(request, new ServerFault("No access token in domain " + domainUid + " token endpoint response",
					ErrorCode.INVALID_PARAMETER));
			return;
		}

		DecodedJWT accessToken;
		try {
			accessToken = JWT.decode(rawAccessToken);
		} catch (JWTDecodeException e) {
			logger.error("Unexpected token endpoint response from domain {}: invalid access_token, fields: {}",
					domainUid, jwtToken.fieldNames(), e);
			error(request, new ServerFault("Invalid access token from domain " + domainUid + " token endpoint",
					ErrorCode.INVALID_PARAMETER));
			return;
		}

		getExternalCreds(request, accessToken).ifPresent(creds -> {
			AuthProvider prov = new AuthProvider(vertx, domainUid, internalAuth);
			createSession(request, prov, forwardedFor, creds, redirectTo, handlerSessionConsumer);
		});
	}

	/**
	 * Extracts the login credentials from the access token claims.
	 *
	 * @return credentials built from the {@code email} claim (and {@code bmuid} when
	 *         present), or empty after reporting a FORBIDDEN error when the
	 *         {@code email} claim is missing or null
	 */
	private Optional<ExternalCreds> getExternalCreds(HttpServerRequest request, DecodedJWT accessToken) {
		Claim email = accessToken.getClaim("email");
		if (email.isMissing() || email.isNull()) {
			error(request, new ServerFault("Invalid access token: no email claim", ErrorCode.FORBIDDEN));
			return Optional.empty();
		}

		ExternalCreds creds = new ExternalCreds();
		creds.setLoginAtDomain(email.asString());

		Claim bmUid = accessToken.getClaim("bmuid");
		if (!bmUid.isMissing() && !bmUid.isNull()) {
			creds.setUserUid(bmUid.asString());
		}

		return Optional.of(creds);
	}
}
