/* BEGIN LICENSE
  * Copyright © Blue Mind SAS, 2012-2024
  *
  * This file is part of BlueMind. BlueMind is a messaging and collaborative
  * solution.
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of either the GNU Affero General Public License as
  * published by the Free Software Foundation (version 3 of the License).
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  *
  * See LICENSE.txt
  * END LICENSE
  */
package net.bluemind.core.tx.wrapper.internal;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;

import javax.sql.DataSource;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Transaction state that enlists one JDBC {@link Connection} per
 * {@link DataSource} and commits or rolls them all back together.
 * <p>
 * Transactions nest: {@link #begin()} increments a nesting level and only the
 * outermost {@link #commit()} / {@link #rollbackFor(Throwable)} really ends the
 * transaction. A nested rollback marks the whole transaction rollback-only:
 * from then on every call on an enlisted connection, except rollback /
 * setAutoCommit / close, fails, and the outermost commit becomes a rollback.
 * <p>
 * There is no two-phase commit: enlisted connections are committed one after
 * the other, so a failure on one data source does not undo commits already
 * done on the others.
 * <p>
 * NOTE(review): only the rollback-only marker is atomic (so {@link #abort()}
 * can presumably be called from another thread); everything else is
 * unsynchronized and assumes the instance is confined to one thread, which
 * matches the ThreadLocal remarks below.
 */
public class ActiveTxContext {

	// Declared first: static initializers below may run before later fields.
	private static final Logger logger = LoggerFactory.getLogger(ActiveTxContext.class);

	private static final AtomicLong XID = new AtomicLong();

	/** Repeated once per nesting level in debug logs. */
	private static final String NESTING_DOLL = "🪆";

	private enum Status {
		NOT_IN_TX, ACTIVE;
	}

	/**
	 * Interface of the proxies handed out for enlisted connections:
	 * {@code close()} keeps the connection enlisted, {@link #backToPool()}
	 * really closes the pooled connection.
	 */
	private static interface InTxCon extends Connection {

		void backToPool();

	}

	/**
	 * Unchecked carrier used to tunnel failures out of lambdas and static
	 * initializers.
	 */
	@SuppressWarnings("serial")
	private static class EnlistException extends RuntimeException {

		public EnlistException(SQLException e) {
			super(e);
		}

		public EnlistException(String string) {
			super(string);
		}

		public EnlistException(String message, Throwable cause) {
			super(message, cause);
		}

	}

	/**
	 * Thrown by {@link ActiveTxContext#rollbackFor(Throwable)} inside a nested
	 * transaction when the rollback cause is a checked throwable.
	 */
	@SuppressWarnings("serial")
	public static class NestedException extends RuntimeException {

		public NestedException(String msg, Throwable e) {
			super(msg, e);
		}

	}

	private static final Method closeMethod = loadClose();
	private static final Method backToPoolMethod = loadBackToPool();
	/** Connection methods still allowed once the transaction is rollback-only. */
	private static final Set<Method> rollbackMethods = loadRollback();

	/** Connections enlisted in the current transaction, keyed by origin pool. */
	private final Map<DataSource, InTxCon> activeConnections = new HashMap<>();

	/*
	 * LinkedList rather than ArrayDeque on purpose: ArrayDeque never shrinks its
	 * backing array, and this object presumably lives in thread-local storage
	 * for a long time.
	 */
	private final Deque<Runnable> recoveryActions = new LinkedList<>();
	private final Deque<Runnable> commitActions = new LinkedList<>();
	private final long txId = XID.incrementAndGet();
	private Status status = Status.NOT_IN_TX;
	private int nestingLvl = 0;
	/** Non-null means rollback-only; holds the first nested failure or abort cause. */
	private final AtomicReference<Throwable> rollbackOnly = new AtomicReference<>();

	/**
	 * Returns a connection from {@code ds}.
	 * <p>
	 * Outside a transaction this is a plain pooled connection. Inside one, the
	 * connection enlisted for {@code ds} is returned (created on first use);
	 * calling {@code close()} on it does not release it, release happens when
	 * the transaction ends.
	 *
	 * @param ds the data source to get a connection from
	 * @return a connection usable by the caller
	 * @throws SQLException if a connection cannot be obtained or prepared; an
	 *                      unchecked exception is thrown instead when enlisting
	 *                      is attempted from a virtual thread
	 */
	public Connection enlistedConnection(DataSource ds) throws SQLException {
		switch (status) {
		case ACTIVE:
			if (Thread.currentThread().isVirtual()) {
				throw new EnlistException(
						"Enlisting a connection for virtual thread is not supported as ThreadLocal is involved");
			}
			return enlist(ds);
		case NOT_IN_TX:
			return ds.getConnection();
		default:
			throw new EnlistException("Unsupported status " + status);
		}
	}

	/** Returns the connection enlisted for {@code ds}, enlisting one if needed. */
	private Connection enlist(DataSource ds) throws SQLException {
		try {
			return activeConnections.computeIfAbsent(ds, pool -> {
				try {
					return wrapInTx(openTxConnection(pool));
				} catch (SQLException e) {
					// computeIfAbsent cannot propagate checked exceptions: tunnel it
					throw new EnlistException(e);
				}
			});
		} catch (EnlistException e) {
			if (e.getCause() instanceof SQLException sqle) {
				throw sqle;
			}
			throw e;
		}
	}

	/**
	 * Gets a connection from {@code pool} and switches it to manual commit. If
	 * that switch fails the connection is closed again instead of being leaked.
	 */
	private static Connection openTxConnection(DataSource pool) throws SQLException {
		Connection fromPool = pool.getConnection();
		try {
			fromPool.setAutoCommit(false);
			return fromPool;
		} catch (SQLException | RuntimeException e) {
			try {
				fromPool.close();
			} catch (SQLException closeError) {
				e.addSuppressed(closeError);
			}
			throw e;
		}
	}

	/**
	 * Wraps an enlisted connection in an {@link InTxCon} proxy:
	 * <ul>
	 * <li>{@code close()} is ignored, the connection stays enlisted;</li>
	 * <li>{@code backToPool()} closes {@code fromPool};</li>
	 * <li>when the transaction is rollback-only, every method outside
	 * {@link #rollbackMethods} fails;</li>
	 * <li>everything else is delegated to {@code fromPool}.</li>
	 * </ul>
	 * Failures are rethrown as {@link SQLException}; an SQLException raised by
	 * the delegate is rethrown as-is, others are wrapped with their cause kept.
	 */
	private InTxCon wrapInTx(Connection fromPool) {
		InvocationHandler handler = (proxy, method, args) -> {
			try {
				if (closeMethod.equals(method)) {
					logger.debug("{} keep enlisted in TX", fromPool);
					return null;
				}
				if (backToPoolMethod.equals(method)) {
					logger.debug("{} goes back to pool", fromPool);
					fromPool.close();
					return null;
				}
				// Read once: abort() may set it concurrently.
				Throwable poison = rollbackOnly.get();
				if (poison != null && !rollbackMethods.contains(method)) {
					throw poison;
				}
				return method.invoke(fromPool, args);
			} catch (InvocationTargetException ite) {
				Throwable target = ite.getTargetException();
				if (target instanceof SQLException sqle) {
					throw sqle;
				}
				throw new SQLException("Proxy method invokation failed: " + ite.getMessage(), target);
			} catch (Throwable t) { // NOSONAR: yes that's what we want
				logger.error("Unknown error during invoke", t);
				throw new SQLException("invoke failed: " + t.getMessage(), t);
			}
		};
		// A non-public proxy interface must be used with its defining class loader.
		return (InTxCon) Proxy.newProxyInstance(InTxCon.class.getClassLoader(), new Class<?>[] { InTxCon.class },
				handler);
	}

	private static Method loadClose() {
		try {
			return Connection.class.getMethod("close");
		} catch (NoSuchMethodException e) {
			throw new EnlistException("no close method ?", e);
		}
	}

	private static Method loadBackToPool() {
		try {
			return InTxCon.class.getMethod("backToPool");
		} catch (NoSuchMethodException e) {
			throw new EnlistException("no backToPool method ?", e);
		}
	}

	private static Set<Method> loadRollback() {
		try {
			return Set.of(Connection.class.getMethod("rollback"),
					Connection.class.getMethod("setAutoCommit", boolean.class));
		} catch (NoSuchMethodException e) {
			// Keep the cause rather than printStackTrace(): this runs during class init.
			throw new EnlistException("no rollback method ?", e);
		}
	}

	/**
	 * Starts a transaction, or enters one more nesting level if a transaction is
	 * already active.
	 *
	 * @return this context, for chaining
	 */
	public ActiveTxContext begin() {
		if (status == Status.NOT_IN_TX) {
			status = Status.ACTIVE;
			logger.trace("BEGIN 😧");
			nestingLvl = 1;
		} else {
			nestingLvl++;
			if (logger.isDebugEnabled()) {
				logger.debug(" NESTING TX {}", NESTING_DOLL.repeat(nestingLvl));
			}
		}
		return this;
	}

	/**
	 * Leaves one nesting level. At the outermost level, every enlisted
	 * connection is committed, or rolled back without an exception if the
	 * transaction was marked rollback-only.
	 *
	 * @return true if the transaction is over, false if we are part of a parent
	 *         transaction
	 * @throws SQLException on commit failure (the transaction is then rolled
	 *                      back), or if called without a matching
	 *                      {@link #begin()}
	 */
	public boolean commit() throws SQLException {
		nestingLvl--;
		if (nestingLvl == 0) {
			Throwable poison = rollbackOnly.get();
			if (poison != null) {
				logger.warn("Force rollback for {}", poison.toString());
				rollback0();
			} else {
				commit0();
			}
		} else if (nestingLvl < 0) {
			// Unbalanced commit: don't leave a negative depth behind.
			nestingLvl = 0;
			throw new SQLException("TX error: nesting level is less than 0");
		} else if (logger.isDebugEnabled()) {
			logger.debug(" UNNEST TX {}", NESTING_DOLL.repeat(nestingLvl));
		}
		return status == Status.NOT_IN_TX;
	}

	/**
	 * Leaves one nesting level because of {@code t}.
	 * <p>
	 * At the outermost level the transaction is rolled back and true is
	 * returned. This also applies when no transaction is active any more, for
	 * example after a failed {@link #commit()} already rolled back. In that case
	 * only the state is reset, so {@code t} does not leak into the next
	 * transaction. Inside a nested transaction the whole transaction becomes
	 * rollback-only and an exception is thrown to unwind to the parent.
	 *
	 * @param t the failure causing the rollback
	 * @return true if the transaction is over
	 * @throws SQLException    if rolling back a connection fails
	 * @throws RuntimeException {@code t} itself when nested and unchecked, or a
	 *                          {@link NestedException} wrapping it otherwise
	 */
	public boolean rollbackFor(Throwable t) throws SQLException {
		nestingLvl--;
		if (nestingLvl <= 0) {
			rollback0();
			return true;
		}
		rollbackOnly.compareAndSet(null, t);
		logger.debug("Nested rollback at lvl {} 🍆", nestingLvl + 1);
		if (t instanceof RuntimeException re) {
			throw re;
		}
		throw new NestedException("Rollback 🍆 to nesting level " + nestingLvl, t);
	}

	/**
	 * Commits and releases every enlisted connection. On success the recovery
	 * actions are dropped and the commit actions run; on failure the recovery
	 * actions run and the first error is thrown.
	 */
	private void commit0() throws SQLException {
		List<SQLException> commitErrors = new ArrayList<>(2);
		try {
			for (InTxCon txCon : activeConnections.values()) {
				boolean endActionFailed = runOrRecord(txCon::commit, commitErrors);
				runOrRecord(() -> txCon.setAutoCommit(true), commitErrors);
				putConBack(txCon, endActionFailed);
			}
			logger.trace("COMMITED 👍🏻");
		} finally {
			resetState();
		}

		if (commitErrors.isEmpty()) {
			recoveryActions.clear();
			runCommitActions();
		} else {
			// This is a commit error, e.g. deferrable constraints. Connections are
			// already released: rollback0() runs the recovery actions, drops the
			// commit actions and resets state.
			rollback0();
			throw withSuppressed(commitErrors);
		}
	}

	private interface MayFail {
		void tryAction() throws Exception; // NOSONAR
	}

	/**
	 * Runs {@code f}, recording any failure in {@code errorStore} (wrapped in
	 * an SQLException if needed) instead of throwing it.
	 *
	 * @return true if the action failed
	 */
	private boolean runOrRecord(MayFail f, List<SQLException> errorStore) {
		try {
			f.tryAction();
			return false;
		} catch (Throwable t) { // NOSONAR
			if (t instanceof SQLException se) {
				errorStore.add(se);
			} else {
				logger.error("Wrapping to SQLException", t);
				errorStore.add(new SQLException(t));
			}
			return true;
		}
	}

	/**
	 * Returns the first recorded error, with the following ones attached as
	 * suppressed exceptions so that they are not lost.
	 */
	private static SQLException withSuppressed(List<SQLException> errors) {
		SQLException first = errors.getFirst();
		for (SQLException other : errors.subList(1, errors.size())) {
			if (other != first) { // self-suppression is illegal
				first.addSuppressed(other);
			}
		}
		return first;
	}

	/** Back to "no transaction": depth 0, no enlisted connection, not rollback-only. */
	private void resetState() {
		status = Status.NOT_IN_TX;
		nestingLvl = 0;
		activeConnections.clear();
		rollbackOnly.set(null);
	}

	/**
	 * Rolls back and releases every enlisted connection, runs the recovery
	 * actions, drops the commit actions and resets the state. Kept public for
	 * existing callers.
	 *
	 * @throws SQLException the first rollback error, others attached as
	 *                      suppressed
	 */
	public void rollback0() throws SQLException {
		List<SQLException> rollbackErrors = new ArrayList<>(2);
		try {
			for (InTxCon txCon : activeConnections.values()) {
				boolean endActionFailed = runOrRecord(txCon::rollback, rollbackErrors);
				runOrRecord(() -> txCon.setAutoCommit(true), rollbackErrors);
				putConBack(txCon, endActionFailed);
			}
			logger.trace("ROLLED BACK 🍆");
		} catch (Throwable t) { // NOSONAR: Yes, that's what we want
			logger.error("Unable to rollback", t);
			throw t;
		} finally {
			runRecoveryActions();
			commitActions.clear();
			resetState();
		}

		if (!rollbackErrors.isEmpty()) {
			throw withSuppressed(rollbackErrors);
		}
	}

	/** Releases the pooled connection behind {@code txCon}; failures are only logged. */
	private void putConBack(InTxCon txCon, boolean endActionFailed) {
		try {
			if (endActionFailed) {
				logger.warn("Put connection back after failed commit/rollback of {}", this);
			}
			txCon.backToPool();
		} catch (Throwable t) { // NOSONAR: We want Throwable here
			logger.error("Can't put back in pool after failed action, EXIT", t);
		}
	}

	private void runRecoveryActions() {
		int reco = runAll(recoveryActions, "Could not recover with {}");
		if (reco > 0) {
			logger.info("Ran {} recovery actions before rollback", reco);
		}
	}

	/** Registers an action to run if the current transaction rolls back; ignored outside a transaction. */
	public void onRollback(Runnable r) {
		if (inTransaction()) {
			recoveryActions.add(r);
		}
	}

	private void runCommitActions() {
		int events = runAll(commitActions, "on-commit action failed {}");
		if (events > 0) {
			logger.debug("Ran {} commit action(s) after commit", events);
		}
	}

	/**
	 * Polls and runs every queued action until the queue is empty. A failing
	 * action is logged with {@code failureLogFormat} and does not stop the
	 * others.
	 *
	 * @return the number of actions that completed without throwing
	 */
	private static int runAll(Deque<Runnable> actions, String failureLogFormat) {
		int succeeded = 0;
		Runnable action;
		while ((action = actions.poll()) != null) {
			try {
				action.run();
				succeeded++;
			} catch (Throwable t) { // NOSONAR: one failing action must not stop the others
				logger.error(failureLogFormat, action, t);
			}
		}
		return succeeded;
	}

	/**
	 * Registers an action to run after a successful commit, or runs it right
	 * away when no transaction is active.
	 */
	public void onCommit(Runnable r) {
		if (inTransaction()) {
			commitActions.add(r);
		} else {
			r.run();
		}
	}

	public boolean inTransaction() {
		return status == Status.ACTIVE;
	}

	@Override
	public String toString() {
		return "TxContext{xid: " + txId + ", s: " + status + ", depth: " + nestingLvl + "}";
	}

	/**
	 * Marks the transaction rollback-only without ending it: further calls on
	 * enlisted connections fail and the outermost commit rolls back.
	 * <p>
	 * NOTE(review): the mark is set even when no transaction is active, and
	 * {@link #begin()} does not clear it, so it then affects the next
	 * transaction. Confirm this is intended.
	 */
	public void abort() {
		logger.error("ABORT {}", this);
		rollbackOnly.compareAndSet(null, new SQLException("ABORTED transaction"));
	}

}
