/* BEGIN LICENSE
  * Copyright © Blue Mind SAS, 2012-2024
  *
  * This file is part of BlueMind. BlueMind is a messaging and collaborative
  * solution.
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of either the GNU Affero General Public License as
  * published by the Free Software Foundation (version 3 of the License).
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  *
  * See LICENSE.txt
  * END LICENSE
  */
package net.bluemind.core.tx.wrapper;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.SQLException;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;

import javax.sql.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.github.benmanes.caffeine.cache.Cache;

import net.bluemind.core.tx.wrapper.internal.ActiveTxContext;
import net.bluemind.core.tx.wrapper.internal.ActiveTxContext.NestedException;
import net.bluemind.core.tx.wrapper.internal.TxAwareCaffeineCache;
import net.bluemind.system.state.provider.IStateProvider.CloningState;
import net.bluemind.system.state.provider.StateProvider;

/**
 * Thread-bound transaction support for JDBC access.
 *
 * <p>
 * {@link #wrap(DataSource)} returns a proxy whose
 * {@link DataSource#getConnection()} hands out connections enlisted in the
 * calling thread's {@link ActiveTxContext}. The {@code atomically(...)} helpers
 * open a transaction on that context, commit it when the code succeeds and roll
 * it back when the code throws.
 */
public class TxEnabler {

	private TxEnabler() {
		// static utility class
	}

	/**
	 * Unchecked exception reporting transaction failures, also used to carry
	 * checked exceptions (and errors) thrown by transactional code.
	 */
	@SuppressWarnings("serial")
	public static class TxFault extends RuntimeException {
		public TxFault(Throwable t) {
			super(t);
		}

		public TxFault(String string) {
			super(string);
		}

		public TxFault(String message, Throwable cause) {
			super(message, cause);
		}
	}

	private static final Logger logger = LoggerFactory.getLogger(TxEnabler.class);

	/**
	 * Contexts of platform threads that have begun a transaction and have not been
	 * cleaned up yet; {@link #shutdown()} aborts the ones still in a transaction.
	 *
	 * <p>
	 * Registration happens when a transaction begins (see {@link #begin()}), not
	 * when the thread-local is initialised: merely querying the context (e.g.
	 * {@link #isInTransaction()} or a non-transactional
	 * {@link DataSource#getConnection()}) must not pin it in this set, otherwise
	 * contexts of short-lived platform threads would never be released.
	 */
	private static final Set<ActiveTxContext> ACTIVE_CONTEXTS = ConcurrentHashMap.newKeySet();
	private static final ThreadLocal<ActiveTxContext> CONTEXTS = ThreadLocal.withInitial(ActiveTxContext::new);
	private static final Method GET_CONNECTION;
	private static final Method EQUALS;
	private static final Method HASH_CODE;

	static {
		try {
			GET_CONNECTION = DataSource.class.getMethod("getConnection");
			EQUALS = Object.class.getMethod("equals", Object.class);
			HASH_CODE = Object.class.getMethod("hashCode");
		} catch (NoSuchMethodException nsm) {
			// keep the original exception as cause for diagnostics
			throw new TxFault("DataSource does not have mandatory method: " + nsm.getMessage(), nsm);
		}
	}

	/**
	 * Registers a context that just began a transaction so that
	 * {@link #shutdown()} can abort it. Contexts owned by virtual threads are not
	 * tracked.
	 *
	 * @param ctx the context returned by {@link ActiveTxContext#begin()}
	 */
	private static void track(ActiveTxContext ctx) {
		if (!Thread.currentThread().isVirtual()) {
			ACTIVE_CONTEXTS.add(ctx);
		}
	}

	/**
	 * Aborts every tracked context that is still in a transaction, then forgets
	 * all tracked contexts.
	 */
	public static final void shutdown() {
		int active = ACTIVE_CONTEXTS.size();
		if (active > 0) {
			logger.info("Shutdown with {}", active);

			ACTIVE_CONTEXTS.forEach(c -> {
				if (c.inTransaction()) {
					logger.error("{} still in transaction", c);
					c.abort();
				}
			});
			ACTIVE_CONTEXTS.clear();
		} else {
			logger.info("Clean shutdown without active tx contexts ({})", ACTIVE_CONTEXTS);
		}
	}

	/**
	 * This is called in early BM start by JdbcActivator, you should not call that
	 * by yourself.
	 *
	 * <p>
	 * The returned proxy routes {@link DataSource#getConnection()} through the
	 * calling thread's transaction context. It implements {@code equals} and
	 * {@code hashCode} by proxy identity and forwards every other method to
	 * {@code ds}. Exceptions thrown by {@code ds} are rethrown unwrapped, so
	 * callers still see the declared {@link SQLException}.
	 *
	 * @param ds the hikari data source
	 * @return a transaction-aware proxy around {@code ds}, {@code ds} itself when
	 *         it already is a proxy, or {@code null} when {@code ds} is
	 *         {@code null}
	 */
	public static DataSource wrap(DataSource ds) {
		if (ds == null) {
			return null;
		}

		if (Proxy.isProxyClass(ds.getClass())) {
			return ds;
		}

		InvocationHandler handler = (proxy, method, args) -> {
			if (method.equals(GET_CONNECTION)) {
				return CONTEXTS.get().enlistedConnection(ds);
			}
			if (method.equals(HASH_CODE)) {
				return System.identityHashCode(proxy);
			}
			if (method.equals(EQUALS)) {
				return proxy == args[0];
			}
			try {
				return method.invoke(ds, args);
			} catch (InvocationTargetException ite) {
				// Rethrow what the target threw: an InvocationTargetException escaping
				// the handler would surface as UndeclaredThrowableException.
				Throwable target = ite.getCause();
				throw target != null ? target : ite;
			}
		};
		return (DataSource) Proxy.newProxyInstance(ds.getClass().getClassLoader(), new Class<?>[] { DataSource.class },
				handler);
	}

	/**
	 * Wraps a caffeine cache into a {@link TxAwareCaffeineCache}.
	 *
	 * @param existingCache the cache to wrap
	 * @return the transaction-aware wrapper
	 */
	public static <K, V> Cache<K, V> wrap(Cache<K, V> existingCache) {
		return new TxAwareCaffeineCache<>(existingCache);
	}

	/**
	 * Begins a (possibly nested) transaction on the calling thread's context and
	 * tracks it for {@link #shutdown()}.
	 */
	private static final ActiveTxContext begin() {
		ActiveTxContext ctx = CONTEXTS.get().begin();
		track(ctx);
		return ctx;
	}

	/**
	 * @return whether the calling thread's context reports an active transaction
	 */
	public static boolean isInTransaction() {
		return CONTEXTS.get().inTransaction();
	}

	/**
	 * This will run on ROLLBACK of the active transaction. This is a NOOP when not
	 * in a transaction
	 * 
	 * For example, when you put something in a cache, you should register a
	 * recovery action to invalidate your stuff on rollback.
	 * 
	 * @param r the action to run on rollback
	 */
	public static final void recoveryAction(Runnable r) {
		CONTEXTS.get().onRollback(r);
	}

	/**
	 * This will run <em>after</em> COMMIT of the active transaction. This runs
	 * <code>r</code> instantly on the calling thread when not in a transaction.
	 * 
	 * @param r the action to run after commit
	 */
	public static final void durableStorageAction(Runnable r) {
		CONTEXTS.get().onCommit(r);
	}

	/**
	 * Calling this will alter the behaviour of our {@link DataSource}.
	 * 
	 * <p>
	 * A transaction is bound to the executing thread. The context will NOT be
	 * propagated over eventbus or through executeBlocking calls.
	 * 
	 * <p>
	 * Subsequent calls to {@link DataSource#getConnection()} will return the same
	 * connection with autocommit disabled.
	 * 
	 * <p>
	 * <code>COMMIT</code> or <code>ROLLBACK</code> will be handled automatically.
	 * If the given runnable succeeds, the transaction will commit. Otherwise, a
	 * <code>ROLLBACK</code> is performed.
	 * 
	 * <p>
	 * Closing the connection inside the runnable will stage it until the end of the
	 * execution. The connection is returned to the pool once the
	 * {@code atomically} execution ends.
	 *
	 * <p>
	 * While the system state is {@link CloningState#CLONING}, the runnable is run
	 * directly, without any transaction.
	 * 
	 * @param r the code to run with a transactional context
	 * @throws TxFault when {@code r} throws an {@link Error}, or when the rollback
	 *                 itself fails; runtime exceptions thrown by {@code r} are
	 *                 rethrown as is
	 */
	public static final void atomically(Runnable r) {
		// Same semantics as the Callable flavour; the value is discarded.
		Callable<Void> asCallable = () -> {
			r.run();
			return null;
		};
		atomically(asCallable);
	}

	/**
	 * Documentation is on {@link TxEnabler#atomically(Runnable)}
	 * 
	 * @param <T> the result type
	 * @param r   the code to run with a transactional context
	 * @return the value returned by {@code r}
	 * @throws TxFault when {@code r} throws a checked exception or an
	 *                 {@link Error}, or when the rollback itself fails
	 */
	public static final <T> T atomically(Callable<T> r) {
		if (StateProvider.state() == CloningState.CLONING) {
			// no transaction while cloning: run the code as is
			try {
				return r.call();
			} catch (Throwable t) { // NOSONAR
				throw propagate(t);
			}
		}

		ActiveTxContext ctx = begin();
		try {
			T ret = r.call();
			ctx.commit();
			return ret;
		} catch (Throwable t) { // NOSONAR
			throw onFailure(ctx, t);
		} finally {
			tryCleanup(ctx);
		}
	}

	/**
	 * Forgets the calling thread's context once its outermost transaction is
	 * over; nested calls leave it in place.
	 */
	private static void tryCleanup(ActiveTxContext ctx) {
		if (!ctx.inTransaction()) {
			ACTIVE_CONTEXTS.remove(ctx);
			CONTEXTS.remove();
		}
	}

	/**
	 * Rolls back {@code ctx} because of {@code t} and returns the exception the
	 * caller must throw.
	 *
	 * <p>
	 * A {@link NestedException} is unwrapped to its cause. If the rollback fails,
	 * the returned {@link TxFault} wraps the {@link SQLException} and keeps the
	 * original failure as a suppressed exception so it is not lost.
	 */
	private static RuntimeException onFailure(ActiveTxContext ctx, Throwable t) {
		Throwable unwrap = t instanceof NestedException ne ? ne.getCause() : t;
		try {
			ctx.rollbackFor(unwrap);
		} catch (SQLException e) {
			TxFault fault = new TxFault(e);
			if (unwrap != null && unwrap != e) {
				fault.addSuppressed(unwrap);
			}
			return fault;
		}
		return propagate(unwrap);
	}

	/**
	 * @return {@code t} itself when it is a {@link RuntimeException}, otherwise a
	 *         {@link TxFault} wrapping it
	 */
	private static RuntimeException propagate(Throwable t) {
		return t instanceof RuntimeException re ? re : new TxFault(t);
	}

}
