// GlobalExceptionHandler.kt

package com.example.templateproject.web.exception

import com.example.templateproject.api.dto.ErrorDTO
import com.example.templateproject.client.exception.ExternalServiceException
import com.example.templateproject.client.exception.ExternalServiceExceptionHandler
import com.example.templateproject.core.exception.BaseException
import com.example.templateproject.core.exception.ExecutionTimeoutException
import com.example.templateproject.web.metrics.ExceptionMetrics
import org.springframework.beans.factory.annotation.Value
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.HttpStatusCode
import org.springframework.http.ResponseEntity
import org.springframework.validation.FieldError
import org.springframework.web.bind.MethodArgumentNotValidException
import org.springframework.web.bind.annotation.ControllerAdvice
import org.springframework.web.bind.annotation.ExceptionHandler
import org.springframework.web.client.ResourceAccessException
import org.springframework.web.context.request.WebRequest
import org.springframework.web.servlet.mvc.method.annotation.ResponseEntityExceptionHandler
import java.util.concurrent.ExecutionException

/**
 * Controller advice that turns exceptions raised during request handling into [ErrorDTO] responses.
 *
 * For every exception it handles, this class:
 * - unwraps [ExecutionException] to its cause when possible,
 * - generates an exception id, updates the exception counter metric and logs the exception,
 * - builds an [ErrorDTO] with the id, a message, optional details and an optional stack trace,
 * - selects the HTTP status from the exception type.
 *
 * @property printStackTraceEnabled value of the `app.stack.trace.enabled` property (defaults to `false`);
 * a stack trace is only included in the response when this flag is set AND the request asks for it.
 * @property exceptionMetrics collaborator used to count handled exceptions.
 * @property externalServiceExceptionHandler collaborator that supplies details for [ExternalServiceException].
 */
@ControllerAdvice
class GlobalExceptionHandler(
    @param:Value($$"${app.stack.trace.enabled:false}") private val printStackTraceEnabled: Boolean,
    private val exceptionMetrics: ExceptionMetrics,
    private val externalServiceExceptionHandler: ExternalServiceExceptionHandler,
) : ResponseEntityExceptionHandler() {
    companion object {
        /** Message used when the handled exception carries no message of its own. */
        const val UNKNOWN_ERROR_MESSAGE = "Unknown internal server error"

        /** Name of the request parameter a client sets to `true` to ask for the stack trace. */
        const val STACK_TRACE_QUERY_PARAMETER_NAME = "stackTrace"
    }

    /**
     * Replaces the body produced by [ResponseEntityExceptionHandler] with the [ErrorDTO] built by
     * [handleExceptions], while keeping the status and headers chosen by the caller.
     *
     * @param ex the exception being handled
     * @param body the body proposed by the caller; ignored, the [ErrorDTO] is used instead
     * @param headers headers to send with the response
     * @param statusCode status to send with the response
     * @param request the current request
     * @return a response carrying the [ErrorDTO] body, the given [headers] and the given [statusCode]
     */
    public override fun handleExceptionInternal(
        ex: java.lang.Exception,
        body: Any?,
        headers: HttpHeaders,
        statusCode: HttpStatusCode,
        request: WebRequest,
    ): ResponseEntity<Any> {
        val errorBody = handleExceptions(ex, request).body
        // Pass the supplied headers through instead of dropping them.
        return ResponseEntity.status(statusCode).headers(headers).body<Any>(errorBody)
    }

    /**
     * Handles any [Exception]: records metrics, logs the exception and builds the error response.
     *
     * @param originalException the exception as it was thrown (possibly an [ExecutionException] wrapper)
     * @param request the current request; consulted for the stack-trace request parameter
     * @return a response whose body is the [ErrorDTO] and whose status is derived from the exception type
     */
    @ExceptionHandler(value = [Exception::class])
    fun handleExceptions(
        originalException: Exception,
        request: WebRequest,
    ): ResponseEntity<ErrorDTO> {
        val exception = unwrapException(originalException)
        val message = getMessage(exception)
        val details = getDetails(exception)
        val exceptionId = ExceptionIdGenerator.generateExceptionId(exception)
        val logPrefix = generateLogPrefix(exception)

        exceptionMetrics.updateExceptionCounter(exceptionId, exception.javaClass.simpleName)
        logger.error("${logPrefix}ExceptionId: $exceptionId - $exception", exception)

        val stackTrace = getStackTrace(exception, request)
        val body = ErrorDTO(exceptionId, message, details, stackTrace)
        val httpStatus = getHttpStatus(exception)
        return ResponseEntity(body, httpStatus)
    }

    /**
     * Returns the cause of an [ExecutionException] when that cause is an [Exception].
     *
     * Falls back to the original exception when it is not an [ExecutionException], when the cause is
     * missing, or when the cause is a non-[Exception] throwable, so that handling never fails on the
     * unwrapping itself.
     */
    private fun unwrapException(originalException: Exception): Exception =
        when (originalException) {
            is ExecutionException -> originalException.cause as? Exception ?: originalException
            else -> originalException
        }

    /** Returns the message for the response: a fixed text for validation errors, else the exception message. */
    private fun getMessage(exception: Exception): String =
        when (exception) {
            is MethodArgumentNotValidException -> "Invalid request content."
            else -> exception.message ?: UNKNOWN_ERROR_MESSAGE
        }

    /**
     * Returns exception-specific details for the response, or `null` when there are none.
     *
     * For [MethodArgumentNotValidException] the result is a map from field name to the list of
     * validation messages for that field.
     */
    private fun getDetails(exception: Exception): Any? =
        when (exception) {
            is MethodArgumentNotValidException -> {
                val errors: Map<String, List<String>> =
                    exception.bindingResult.allErrors
                        // Global (non-field) errors are not FieldError instances; key them by object name.
                        .groupBy { error -> (error as? FieldError)?.field ?: error.objectName }
                        .mapValues { (_, groupedErrors) ->
                            // The default message may be absent; fall back to the error code, skip if neither exists.
                            groupedErrors.mapNotNull { error -> error.defaultMessage ?: error.code }
                        }
                errors
            }
            is ExternalServiceException -> externalServiceExceptionHandler.getDetails(exception)
            is ExecutionTimeoutException -> exception.details
            else -> null
        }

    /** Returns a `"[serviceName] - "` log prefix for [ExternalServiceException], or an empty string otherwise. */
    private fun generateLogPrefix(exception: Exception): String =
        when (exception) {
            is ExternalServiceException -> "[${exception.serviceName}] - "
            else -> ""
        }

    /**
     * Returns the stack trace as text only when stack traces are enabled by configuration and the
     * request explicitly asks for one; otherwise returns `null`.
     */
    private fun getStackTrace(
        exception: Exception,
        request: WebRequest,
    ): String? = if (printStackTraceEnabled && isStackTraceRequested(request)) exception.stackTraceToString() else null

    /**
     * Returns `true` when any value of the [STACK_TRACE_QUERY_PARAMETER_NAME] request parameter parses
     * to `true`; a missing parameter or a blank/other value yields `false`.
     */
    private fun isStackTraceRequested(request: WebRequest): Boolean =
        request.getParameterValues(STACK_TRACE_QUERY_PARAMETER_NAME)?.any { it.toBoolean() } ?: false

    /**
     * Maps the exception type to the HTTP status of the response; anything unrecognised becomes
     * [HttpStatus.INTERNAL_SERVER_ERROR].
     */
    private fun getHttpStatus(exception: Exception): HttpStatusCode =
        when (exception) {
            is BaseException -> exception.httpStatus
            is MethodArgumentNotValidException -> exception.statusCode
            is ResourceAccessException -> HttpStatus.GATEWAY_TIMEOUT
            else -> HttpStatus.INTERNAL_SERVER_ERROR
        }
}