/* BEGIN LICENSE
  * Copyright © Blue Mind SAS, 2012-2025
  *
  * This file is part of BlueMind. BlueMind is a messaging and collaborative
  * solution.
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of either the GNU Affero General Public License as
  * published by the Free Software Foundation (version 3 of the License).
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  *
  * See LICENSE.txt
  * END LICENSE
  */
package net.bluemind.memory.pool.mmap;

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.nio.file.Files;
import java.nio.file.Path;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.bluemind.memory.pool.mmap.FdCreator.FdAndSizes;

public class MmapSupport {

	private static final Logger logger = LoggerFactory.getLogger(MmapSupport.class);

	public enum Advice {
		REMOVE(9),

		COLD(20),

		PAGEOUT(21),

		DONTNEED(4);

		private final int value;

		private Advice(int value) {
			this.value = value;
		}
	}

	// Constantes mmap
	private static final int PROT_READ = 0x1;
	private static final int PROT_WRITE = 0x2;
	private static final int MAP_SHARED = 0x01;
	@SuppressWarnings("unused")
	private static final int MAP_PRIVATE = 0x02;
	private static final long MAP_FAILED = -1L;

	// Constantes pour open()
	private static final int O_RDWR = 0x0002;
	private static final int O_CREAT = 0x0040;
	private static final int S_IRUSR = 0x0100;
	private static final int S_IWUSR = 0x0080;

	// memfd_create
	public static final int MFD_CLOEXEC = 0x0001;

	// Linker & symbols
	private static final Linker LINKER = Linker.nativeLinker();
	private static final MethodHandle madvise;
	private static final MethodHandle strerror;
	private static final MethodHandle errno_location;
	private static final MethodHandle mmap_handle;
	private static final MethodHandle munmap_handle;
	private static final MethodHandle open_handle;
	private static final MethodHandle close_handle;
	private static final MethodHandle ftruncate_handle;
	private static final MethodHandle memfd_handle;

	static {
		try {
			// Recherche de la fonction madvise
			SymbolLookup stdlib = LINKER.defaultLookup();
			MemorySegment madviseSymbol = stdlib.find("madvise")
					.orElseThrow(() -> new RuntimeException("madvise not found"));

			// Signature: int madvise(void *addr, size_t length, int advice)
			FunctionDescriptor madviseDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, // retour int
					ValueLayout.ADDRESS, // void *addr
					ValueLayout.JAVA_LONG, // size_t length
					ValueLayout.JAVA_INT // int advice
			);

			madvise = LINKER.downcallHandle(madviseSymbol, madviseDesc);

			// Recherche de la fonction strerror
			MemorySegment strerrorSymbol = stdlib.find("strerror")
					.orElseThrow(() -> new RuntimeException("strerror not found"));

			// Signature: char *strerror(int errnum)
			FunctionDescriptor strerrorDesc = FunctionDescriptor.of(ValueLayout.ADDRESS, // char *
					ValueLayout.JAVA_INT // int errnum
			);

			strerror = LINKER.downcallHandle(strerrorSymbol, strerrorDesc);

			// Recherche de __errno_location (Linux) pour obtenir errno
			MemorySegment errnoSymbol = stdlib.find("__errno_location")
					.orElseThrow(() -> new RuntimeException("__errno_location not found"));

			// Signature: int *__errno_location(void)
			FunctionDescriptor errnoDesc = FunctionDescriptor
					.of(ValueLayout.ADDRESS.withTargetLayout(ValueLayout.JAVA_INT) // int *
					);
			errno_location = LINKER.downcallHandle(errnoSymbol, errnoDesc);

			// mmap: void *mmap(void *addr, size_t length, int prot, int flags, int fd,
			// off_t offset)
			MemorySegment mmapSymbol = stdlib.find("mmap").orElseThrow(() -> new RuntimeException("mmap not found"));
			FunctionDescriptor mmapDesc = FunctionDescriptor.of(ValueLayout.ADDRESS, // void *
					ValueLayout.ADDRESS, // void *addr
					ValueLayout.JAVA_LONG, // size_t length
					ValueLayout.JAVA_INT, // int prot
					ValueLayout.JAVA_INT, // int flags
					ValueLayout.JAVA_INT, // int fd
					ValueLayout.JAVA_LONG // off_t offset
			);
			mmap_handle = LINKER.downcallHandle(mmapSymbol, mmapDesc);

			// munmap: int munmap(void *addr, size_t length)
			MemorySegment munmapSymbol = stdlib.find("munmap")
					.orElseThrow(() -> new RuntimeException("munmap not found"));
			FunctionDescriptor munmapDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS,
					ValueLayout.JAVA_LONG);
			munmap_handle = LINKER.downcallHandle(munmapSymbol, munmapDesc);

			// open: int open(const char *pathname, int flags, mode_t mode)
			MemorySegment openSymbol = stdlib.find("open").orElseThrow(() -> new RuntimeException("open not found"));
			FunctionDescriptor openDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS,
					ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
			open_handle = LINKER.downcallHandle(openSymbol, openDesc);

			// close: int close(int fd)
			MemorySegment closeSymbol = stdlib.find("close").orElseThrow(() -> new RuntimeException("close not found"));
			FunctionDescriptor closeDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
			close_handle = LINKER.downcallHandle(closeSymbol, closeDesc);

			// ftruncate: int ftruncate(int fd, off_t length)
			MemorySegment ftruncateSymbol = stdlib.find("ftruncate")
					.orElseThrow(() -> new RuntimeException("ftruncate not found"));
			FunctionDescriptor ftruncateDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT,
					ValueLayout.JAVA_LONG);
			ftruncate_handle = LINKER.downcallHandle(ftruncateSymbol, ftruncateDesc);

			// int memfd_create(const char *name, unsigned int flags)
			MemorySegment memfdCreateSymbol = stdlib.find("memfd_create")
					.orElseThrow(() -> new RuntimeException("memfd_create not found"));
			FunctionDescriptor memfdDesc = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS,
					ValueLayout.JAVA_INT);
			memfd_handle = LINKER.downcallHandle(memfdCreateSymbol, memfdDesc);

		} catch (Throwable e) { // NOSONAR
			throw new ExceptionInInitializerError(e);
		}
	}

	@SuppressWarnings("serial")
	private static class MmapSupportRTE extends RuntimeException {

		public MmapSupportRTE(String string, Throwable e) {
			super(string, e);
		}

		public MmapSupportRTE(Throwable t) {
			super(t);
		}

		public static void propagate(Throwable t) {
			if (t instanceof RuntimeException rte) {
				throw rte;
			} else {
				throw new MmapSupportRTE(t);
			}
		}

		public static <T> T propagate(Throwable t, T useless) { // NOSONAR unused
			if (t instanceof RuntimeException rte) {
				throw rte;
			} else {
				throw new MmapSupportRTE(t);
			}
		}

	}

	public static record MappedSegment(MemorySegment seg, FdAndSizes fd, SizeRecorder sizeRec) {

		public MappedSegment {
			sizeRec.updateLockedIn(fd.locked());
			sizeRec.updateFlushable(fd.flushable());
		}

		public void cleanup() {
			try {
				// https://github.com/torvalds/linux/blob/0c3836482481200ead7b416ca80c68a29cfdaabd/mm/madvise.c#L843
				madvise(seg, Advice.DONTNEED);

				munmap(seg);
				close_handle.invoke(fd.fd());
				sizeRec.updateLockedIn(0 - fd.locked());
				sizeRec.updateFlushable(0 - fd.flushable());
			} catch (Throwable e) {
				MmapSupportRTE.propagate(e);
			}
		}
	}

	private static int getErrno() {
		try {
			MemorySegment errnoPtr = (MemorySegment) errno_location.invoke();
			return errnoPtr.get(ValueLayout.JAVA_INT, 0);
		} catch (Throwable e) {
			throw new MmapSupportRTE("Failed to get errno", e);
		}
	}

	private static String getErrorMessage(int errno) {
		try {
			MemorySegment errorStrPtr = (MemorySegment) strerror.invoke(errno);
			return errorStrPtr.reinterpret(Long.MAX_VALUE).getString(0);
		} catch (Throwable e) {
			return "Unknown error: " + errno + " (" + e.getMessage() + ")";
		}
	}

	public static FdCreator realFile() {
		return (path, arena, size) -> {
			MemorySegment pathSegment = arena.allocateFrom(path.toString());
			int fd = (int) open_handle.invoke(pathSegment, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
			ioExceptionOnError("open", fd);
			return new FdAndSizes(fd, 0L, size);
		};
	}

	public static FdCreator memfd() {
		return (path, arena, size) -> {
			MemorySegment pathSegment = arena.allocateFrom(path.getFileName().toString());
			int fd = (int) memfd_handle.invoke(pathSegment, MFD_CLOEXEC);
			ioExceptionOnError("memfd_create", fd);
			return new FdAndSizes(fd, size, 0L);
		};
	}

	public static MappedSegment mmap(Arena arena, Path backingFile, FdCreator creator, SizeRecorder sizeRecorder,
			long size) {
		try {

			FdAndSizes fdWithSize = creator.newFd(backingFile, arena, size);
			int fd = fdWithSize.fd();
			try {
				ftruncate(fd, size);

				MemorySegment addr = MemorySegment.NULL;
				MemorySegment result = (MemorySegment) mmap_handle.invoke(addr, size, PROT_READ | PROT_WRITE,
						MAP_SHARED, fd, 0L);

				if (result.address() == MAP_FAILED) {
					int errno = getErrno();
					String errorMsg = getErrorMessage(errno);
					throw new IOException("mmap failed: " + errorMsg + " (errno=" + errno + ")");
				}

				MemorySegment mapped = result.reinterpret(size, arena, null);

				MappedSegment prepared = new MappedSegment(mapped, fdWithSize, sizeRecorder);
				if (logger.isInfoEnabled()) {
					logger.info("Mapped {} MB, sizes: {}", size / 1024 / 1024, sizeRecorder);
				}
				return prepared;

			} finally {
				Files.deleteIfExists(backingFile);
			}

		} catch (Throwable e) {
			return MmapSupportRTE.propagate(e, null);
		}
	}

	private static void ftruncate(int fd, long maxSize) throws Throwable {
		int truncateResult = (int) ftruncate_handle.invoke(fd, maxSize);
		ioExceptionOnError("ftruncate", truncateResult);
	}

	private static void munmap(MemorySegment mapping) {
		try {
			if (mapping == null || mapping == MemorySegment.NULL) {
				logger.warn("Attempted to munmap null segment");
				return;
			}

			long size = mapping.byteSize();

			int result = (int) munmap_handle.invoke(mapping, size);
			ioExceptionOnError("munmap", result);

		} catch (Throwable e) {
			MmapSupportRTE.propagate(e);
		}
	}

	public static void madvise(MemorySegment segment, Advice flag) {
		try {
			int result = (int) madvise.invoke(segment, segment.byteSize(), flag.value);
			ioExceptionOnError("madvise", result);
		} catch (Throwable e) {
			MmapSupportRTE.propagate(e);
		}
	}

	private static void ioExceptionOnError(String call, int result) throws IOException {
		if (result == -1) {
			int errno = getErrno();
			String errorMsg = getErrorMessage(errno);
			throw new IOException("Failed to '" + call + "'" + errorMsg + " (errno=" + errno + ")");
		}
	}

}
