/* BEGIN LICENSE
 * Copyright © Blue Mind SAS, 2012-2024
 *
 * This file is part of BlueMind. BlueMind is a messaging and collaborative
 * solution.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of either the GNU Affero General Public License as
 * published by the Free Software Foundation (version 3 of the License).
 *
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See LICENSE.txt
 * END LICENSE
*/
package net.bluemind.keydb.sessiondata;

import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.lettuce.core.RedisClient;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
import io.lettuce.core.pubsub.api.sync.RedisPubSubCommands;
import io.vertx.core.json.DecodeException;
import io.vertx.core.json.JsonObject;
import net.bluemind.keydb.common.ClientProvider;
import net.bluemind.keydb.common.KeydbBootstrapNetAddress;
import net.bluemind.keydb.sessiondata.cache.SessionDataCacheFactory;

/**
 * KeyDB-backed store for {@link SessionData}, keyed by session id.
 *
 * <p>
 * Session data is stored as JSON under {@code data:sid:<sid>}. When the
 * session's JWT carries a {@code session_state} claim, a secondary key
 * {@code data:ssid:<session_state>} maps that session state back to the session
 * id. Sessions eligible for OpenID refresh are pushed onto the
 * {@code openid:refresh} list; {@link #getSessionIdToRefresh()} moves them to
 * {@code openid:refresh:inflight} while they are being processed.
 *
 * <p>
 * The KeyDB connection is established by a background thread started when the
 * singleton is first obtained through {@link #get()}; {@link #getCommands()}
 * blocks until that connection is available.
 */
public class SessionDataStore {
	private static final Logger logger = LoggerFactory.getLogger(SessionDataStore.class);

	/**
	 * Completed with the synchronous command API once the KeyDB connection is
	 * established, or completed exceptionally if the connect thread is
	 * interrupted while waiting to retry.
	 */
	public static final CompletableFuture<RedisPubSubCommands<String, String>> storeStart = new CompletableFuture<>();

	// Written by the connect thread only; callers go through storeStart instead.
	private StatefulRedisPubSubConnection<String, String> connection;
	private RedisPubSubCommands<String, String> commands;

	/** JWT claim holding the identity provider session state. */
	private static final String JWT_SESSION_STATE = "session_state";

	/** List of session ids waiting for an OpenID refresh. */
	private static final String SESSIONID_REFRESH_STORE = "openid:refresh";
	/** List of session ids currently being refreshed. */
	private static final String SESSIONID_REFRESHED_STORE = SESSIONID_REFRESH_STORE + ":inflight";
	/** Delay, in milliseconds, after session creation before the first refresh (30 minutes). */
	public static final long SESSIONID_REFRESH_PERIOD = TimeUnit.MINUTES.toMillis(30);

	/** Key prefix of the session data JSON, suffixed by the session id. */
	public static final String DATA_VALUE_HOLDER = "data:sid:";
	/** Key prefix of the session state to session id mapping. */
	private static final String DATA_VALUE_SESSIONSTATE_HOLDER = "data:ssid:";

	/** Cache consulted before KeyDB in {@link #getIfPresent(String)}. */
	private final SessionDataCacheFactory.SessionDataCache sessionDataCache = new SessionDataCacheFactory().get();

	private SessionDataStore() {
		Thread.ofPlatform().name("bm-keydb-data-sessions-connect").start(this::connectUntilReady);
	}

	/**
	 * Connects to KeyDB, retrying every 5 seconds on failure, and completes
	 * {@link #storeStart} on success. If the thread is interrupted while waiting
	 * to retry, {@link #storeStart} is completed exceptionally and retrying
	 * stops.
	 */
	private void connectUntilReady() {
		do {
			RedisClient redisClient = null;
			try {
				redisClient = ClientProvider.newClient();
				connection = redisClient.connectPubSub();
				commands = connection.sync();
				storeStart.complete(commands);
				logger.info("Keydb connection setup completed ({})", KeydbBootstrapNetAddress.getKeydbIP());
			} catch (Exception e) {
				logger.error(e.getMessage(), e);
				if (commands == null) {
					// This attempt produced no usable connection: release the client
					// before creating a new one, otherwise each retry leaks one.
					shutdownQuietly(redisClient);
				}
				try {
					Thread.sleep(Duration.ofSeconds(5));
				} catch (InterruptedException ie) {
					Thread.currentThread().interrupt();
					storeStart.completeExceptionally(ie);
					break;
				}
			}
		} while (commands == null);
	}

	/**
	 * Shuts down a client from a failed connection attempt. Failures are only
	 * logged because the caller is about to retry anyway.
	 *
	 * @param redisClient client to release, may be null
	 */
	private static void shutdownQuietly(RedisClient redisClient) {
		if (redisClient == null) {
			return;
		}
		try {
			redisClient.shutdown();
		} catch (Exception e) {
			logger.warn("Unable to shutdown keydb client: {}", e.getMessage());
		}
	}

	private static class LazyHolder {
		static final SessionDataStore INSTANCE = new SessionDataStore();
	}

	/**
	 * Returns the singleton store. The first call starts the background KeyDB
	 * connection thread.
	 */
	public static SessionDataStore get() {
		return LazyHolder.INSTANCE;
	}

	/**
	 * Returns the KeyDB command API, blocking until the connection is established.
	 *
	 * @throws KeyDbConnectionNotAvalaible if the wait is interrupted or the
	 *                                     connection setup was aborted
	 */
	public RedisPubSubCommands<String, String> getCommands() {
		try {
			return storeStart.get();
		} catch (InterruptedException ie) {
			Thread.currentThread().interrupt();
			throw new KeyDbConnectionNotAvalaible(ie);
		} catch (ExecutionException e) {
			throw new KeyDbConnectionNotAvalaible(e);
		}
	}

	/** Thrown when the KeyDB command API cannot be obtained. */
	@SuppressWarnings("serial")
	public static class KeyDbConnectionNotAvalaible extends RuntimeException {
		KeyDbConnectionNotAvalaible(Throwable t) {
			super("NO: KeydbSessionsClient is not available: ", t);
		}
	}

	private String valueHolder(String sid) {
		return DATA_VALUE_HOLDER + sid;
	}

	private String sessionStateValueHolder(String sessionStateId) {
		return DATA_VALUE_SESSIONSTATE_HOLDER + sessionStateId;
	}

	/**
	 * Looks up session data, first in the cache, then in KeyDB. A value read
	 * from KeyDB is put in the cache.
	 *
	 * @param sid session id
	 * @return the session data, or null if absent or if the stored JSON cannot be
	 *         decoded
	 */
	public SessionData getIfPresent(String sid) {
		SessionData sessionData = sessionDataCache.getIfPresent(sid);
		if (sessionData != null) {
			return sessionData;
		}

		String strValue = getCommands().get(valueHolder(sid));
		if (strValue != null) {
			try {
				sessionData = SessionData.fromJson(new JsonObject(strValue));
			} catch (DecodeException de) {
				// Neither the sid nor the payload is logged: both may be credentials.
				logger.warn("Unable to decode session data stored in keydb, ignoring it");
				return null;
			}
			sessionDataCache.put(sid, sessionData);
		}

		return sessionData;
	}

	/**
	 * Resolves session data from a JWT {@code session_state} value.
	 *
	 * @param sessionStateId session state value
	 * @return the associated session data, or null if there is no mapping or no
	 *         data
	 */
	public SessionData getFromSessionState(String sessionStateId) {
		String sid = getCommands().get(sessionStateValueHolder(sessionStateId));
		return sid != null ? getIfPresent(sid) : null;
	}

	/**
	 * Checks whether session data exists in KeyDB (the cache is not consulted).
	 */
	public boolean exists(String sid) {
		return getCommands().exists(valueHolder(sid)) > 0;
	}

	/**
	 * Stores session data. Internal OpenID sessions are also enqueued for refresh
	 * and get their refresh stamp initialized if unset. The cached entry is
	 * invalidated and the session state mapping is recorded when possible.
	 */
	public void put(SessionData sessionData) {
		sessionData = openIdRefreshQueue(sessionData);
		getCommands().set(valueHolder(sessionData.authKey), SessionData.toJson(sessionData).encode());
		sessionDataCache.invalidate(sessionData.authKey);

		setSessionStateToSessionId(sessionData.authKey, sessionData.jwtToken);
	}

	/**
	 * Overwrites session data only if the session still exists in KeyDB. Unlike
	 * {@link #put(SessionData)}, this does not touch the refresh queue.
	 */
	public void updateSessionData(SessionData sessionData) {
		if (exists(sessionData.authKey)) {
			getCommands().set(valueHolder(sessionData.authKey), SessionData.toJson(sessionData).encode());
			sessionDataCache.invalidate(sessionData.authKey);

			setSessionStateToSessionId(sessionData.authKey, sessionData.jwtToken);
		}
	}

	/**
	 * Records the session state to session id mapping when the JWT carries a
	 * {@code session_state} claim. An existing mapping is never overwritten.
	 */
	private void setSessionStateToSessionId(String sessionId, JsonObject jwtToken) {
		if (jwtToken == null) {
			return;
		}

		String sessionState = (String) jwtToken.getValue(JWT_SESSION_STATE);
		if (sessionState == null) {
			// Without the claim there is nothing to index; writing "data:ssid:null"
			// would create a shared bogus key that invalidate() never deletes.
			return;
		}

		// SETNX keeps "first session wins" atomically instead of EXISTS + SET.
		getCommands().setnx(sessionStateValueHolder(sessionState), sessionId);
	}

	/**
	 * Pushes internal OpenID sessions onto the refresh queue and initializes
	 * their refresh stamp to {@code createStamp + SESSIONID_REFRESH_PERIOD} when
	 * it is unset (-1). Other sessions are returned unchanged.
	 */
	private SessionData openIdRefreshQueue(SessionData sessionData) {
		if (sessionData.authKey == null || sessionData.jwtToken == null || !sessionData.internalAuth) {
			if (logger.isDebugEnabled()) {
				logger.debug("Do not add session {} to refresh queue: not an internal OpenID token",
						sessionData.authKey);
			}

			return sessionData;
		}

		long refreshStamp = sessionData.openIdRefreshStamp;
		if (refreshStamp == -1) {
			refreshStamp = sessionData.createStamp + SESSIONID_REFRESH_PERIOD;
		}

		if (logger.isDebugEnabled()) {
			logger.debug("Enqueue session {} to refresh queue", sessionData.authKey);
		}

		getCommands().rpush(SESSIONID_REFRESH_STORE, sessionData.authKey);
		return sessionData.setOpenIdRefreshStamp(refreshStamp);
	}

	/**
	 * Removes a session: its session state mapping (if any), its KeyDB data, its
	 * cached entry, and one occurrence from the in-flight refresh list.
	 */
	public void invalidate(String sid) {
		Optional.ofNullable(getIfPresent(sid)).map(sessionData -> sessionData.jwtToken)
				.map(jwtToken -> (String) jwtToken.getValue(JWT_SESSION_STATE)).map(this::sessionStateValueHolder)
				.ifPresent(ssid -> getCommands().del(ssid));

		getCommands().del(valueHolder(sid));
		sessionDataCache.invalidate(sid);

		removeRefreshSessionId(sid);
	}

	/**
	 * Moves every session id from the in-flight list back to the refresh queue.
	 */
	public void requeueRefreshSessionId() {
		int requeued = 0;
		String sessionId;
		while ((sessionId = getCommands().rpoplpush(SESSIONID_REFRESHED_STORE, SESSIONID_REFRESH_STORE)) != null) {
			if (logger.isDebugEnabled()) {
				logger.debug("Requeue session {}", sessionId);
			}

			requeued++;
		}

		if (logger.isDebugEnabled()) {
			logger.debug("Requeue {} session ID to refresh queue", requeued);
		}
	}

	/**
	 * Atomically moves the next session id from the refresh queue to the
	 * in-flight list.
	 *
	 * @return the session id, or null when the refresh queue is empty
	 */
	public String getSessionIdToRefresh() {
		return getCommands().rpoplpush(SESSIONID_REFRESH_STORE, SESSIONID_REFRESHED_STORE);
	}

	/**
	 * Removes one occurrence of the session id from the in-flight refresh list.
	 */
	public void removeRefreshSessionId(String sessionId) {
		getCommands().lrem(SESSIONID_REFRESHED_STORE, 1, sessionId);
	}
}
