import os
from typing import List

import processing
from ApplicationConfig import ApplicationConfig
from Layer import (
    getClusterOrSiteClusterLayerName,
    getGridGroup,
    getLayerByKeyword,
    getRimoLayer,
    getWmsLayer,
    setStyle,
)
from LayerMetadataHelper import LayerMetadataHelper
from MessageAggregator import MessageAggregator
from Project import hasRimoData, saveCurrentProject
from ProjectHelper import getExistingProjectFilePath
from ProjectState import ProjectState
from LayerType import LayerType
from qgis.core import (
    QgsCoordinateReferenceSystem,
    QgsCoordinateTransform,
    QgsDefaultValue,
    QgsField,
    QgsMapSettings,
    QgsPrintLayout,
    QgsReadWriteContext,
    QgsVectorFileWriter,
    QgsVectorLayer,
)

from qgis.PyQt.QtXml import QDomDocument
from qgis.PyQt.QtCore import QMetaType
from Result import Result


class ComposerData:
    """Location names written into the print layout's header text frames.

    Values are stored exactly as passed in; nothing is validated.

    Attributes:
        region: Region name.
        cluster: Cluster name.
        siteCluster: Site-cluster name.
    """

    def __init__(self, region: str, cluster: str, siteCluster: str):
        self.region, self.cluster, self.siteCluster = region, cluster, siteCluster


def createPrintComposer(
    printingtype,
    paperFormat,
    printingScale,
    qgisProject,
    qgisInterface,
    appConfig: ApplicationConfig,
    projectState: ProjectState,
    composerData: ComposerData,
    infoMessageAggregator: MessageAggregator,
):
    """Create a RiMo print layout (individual or atlas) and save the project.

    The project is first validated: it must contain RiMo data and a Trenches
    layer, the printing type must be supported, no layout of the same name
    may exist, and a template plus a grid group must be available. Then the
    atlas grid for the chosen paper format / scale is built, the layout is
    created from its template and the current project is saved.

    Args:
        printingtype: "Individual Print", "Atlas" or "Atlas_StaticLegend".
        paperFormat: Paper size name. It is used in the layout/template name
            and in the grid lookup.
        printingScale: Scale denominator. It is used in the layout/template
            name and in the grid lookup.
        qgisProject: The current QGIS project.
        qgisInterface: QGIS interface, passed on to open the layout designer.
        appConfig: Application configuration (templates, paths, EPSG).
        projectState: Current project state (names, styling).
        composerData: Optional header texts for the layout.
        infoMessageAggregator: Collects non-fatal info messages.

    Returns:
        A failed ``Result`` describing the first problem encountered, or
        otherwise the ``Result`` returned by ``saveCurrentProject``.
    """
    if not hasRimoData(qgisProject):
        return Result.fail("Please load a rimo project!")

    if getLayerByKeyword("Trenches", qgisProject).isFailure():
        return Result.fail(
            "The layer trenches must be loaded in the current project to use this function!"
        )

    if printingtype not in ("Individual Print", "Atlas", "Atlas_StaticLegend"):
        return Result.fail(
            f"Composer type {printingtype} is not supported at the moment!"
        )

    # Layout kind; it is part of the layout/template name and steers the
    # atlas setup in createNewRimoComposer.
    createAtlas = "Individual" if printingtype == "Individual Print" else printingtype
    composerName = f"{createAtlas}_{paperFormat}_{printingScale}"

    if checkIfLayoutExists(composerName, qgisProject):
        return Result.fail(
            "Print Layout already exists. Check out your Layout Manager!",
        )

    composerTemplateFilePathResult = appConfig.getRimoPrintComposerTemplateFile(
        composerName
    )
    if composerTemplateFilePathResult.isFailure():
        return Result.fail(
            "No print composer template for the chosen scale / format found.",
        )

    # The generated grid layer is added to this group.
    gridGroupResult = getGridGroup(qgisProject)
    if gridGroupResult.isFailure():
        return Result.fail("Grid Group is missing. Can't create Composer")
    gridLayerGroup = gridGroupResult.value

    gridLayerValues = getGridLayerValues(paperFormat, printingScale)

    clusterPolygonNameResult = getClusterOrSiteClusterLayerName(
        qgisProject, projectState
    )
    if clusterPolygonNameResult.isFailure():
        return Result.fail(
            f"{clusterPolygonNameResult.message}. Print composer can't be created!"
        )
    clusterPolygonName = clusterPolygonNameResult.value

    polygonLayerResult = getLayerByKeyword(clusterPolygonName, qgisProject)
    if polygonLayerResult.isFailure():
        return polygonLayerResult
    polygonLayer = polygonLayerResult.value

    atlasGridNameResult = createAtlasGrid(
        polygonLayer,
        gridLayerGroup,
        gridLayerValues,
        qgisProject,
        appConfig,
        projectState,
    )
    if atlasGridNameResult.isFailure():
        return atlasGridNameResult
    atlasGridName = atlasGridNameResult.value

    atlasGridResult = getLayerByKeyword(atlasGridName, qgisProject)
    if atlasGridResult.isFailure():
        return atlasGridResult
    atlasGrid = atlasGridResult.value
    styling = projectState.currentStyling

    composerResult = createNewRimoComposer(
        composerTemplateFilePathResult.value,
        composerName,
        createAtlas,
        printingScale,
        atlasGrid,
        styling,
        qgisProject,
        appConfig,
        qgisInterface,
        projectState,
        composerData,
        infoMessageAggregator,
    )
    # createNewRimoComposer reports problems (e.g. a missing revision layer)
    # by returning a failed Result. Surface it instead of saving the project
    # as if everything had worked.
    if composerResult is not None and composerResult.isFailure():
        return composerResult

    projectName = projectState.getProjectName()
    qgisProjectExistsResult = getExistingProjectFilePath(
        appConfig.getProjectDirPath(projectName), projectName
    )
    if qgisProjectExistsResult.isFailure():
        return qgisProjectExistsResult

    return saveCurrentProject(
        qgisProject,
        qgisProjectExistsResult.value,
        projectState,
        appConfig,
        infoMessageAggregator,
    )


def createNewRimoComposer(
    composerTemplateFilePath: str,
    composerName: str,
    createAtlas,
    printingScale,
    atlasGrid,
    styling,
    qgisProject,
    appConfig: ApplicationConfig,
    qgisInterface,
    projectState: ProjectState,
    composerData: ComposerData,
    infoMessageAggregator: MessageAggregator,
):
    """Create a print layout from a .qpt template, fill it and open the designer.

    Args:
        composerTemplateFilePath: Path of the layout template to load.
        composerName: Name given to the new layout.
        createAtlas: Layout kind: "Individual", "Atlas" or "Atlas_StaticLegend".
        printingScale: Scale denominator (anything ``int()`` accepts).
        atlasGrid: Grid layer shown in the overview map and used as atlas
            coverage layer. Its extent is transformed from EPSG:4326 below,
            so it is expected to be in EPSG:4326.
        styling: Styling key passed to ``appConfig.getLegendPath``.
        qgisProject: Project the layout is registered in.
        appConfig: Supplies logo/legend/north-arrow paths, company info and EPSG.
        qgisInterface: Opens the layout designer and provides the canvas extent.
        projectState: Supplies region/cluster/site-cluster names for atlas
            output file names.
        composerData: Optional header frame texts; skipped when None.
        infoMessageAggregator: Collects non-fatal messages (e.g. missing
            company info).

    Returns:
        A failed ``Result`` when no "revision" layer exists, otherwise None.
    """
    # Resolve the revision table source before touching the layout manager,
    # so a missing layer does not leave a half-filled layout registered in
    # the project.
    revisionLayerResult = getLayerByKeyword("revision", qgisProject)
    if revisionLayerResult.isFailure():
        return Result.fail(
            "Revision CSV isn't found in layer tree and can't be loaded to printing header!",
        )
    revisionLayer = revisionLayerResult.value

    # `with` closes the template file even if reading fails.
    with open(composerTemplateFilePath, "r") as templateFile:
        templateContent = templateFile.read()

    templateDocument = QDomDocument()
    templateDocument.setContent(templateContent, False)

    qgisLayout = QgsPrintLayout(qgisProject)
    qgisLayout.setName(composerName)
    qgisProject.layoutManager().addLayout(qgisLayout)
    qgisLayout.loadFromTemplate(templateDocument, QgsReadWriteContext(), True)

    # Map items defined by the template.
    # NOTE(review): the maps already belong to the layout after
    # loadFromTemplate, so re-adding them looks redundant. It is kept to
    # preserve existing behavior.
    mainMap = qgisLayout.itemById("mainMap")
    qgisLayout.addLayoutItem(mainMap)
    overviewMap = qgisLayout.itemById("overviewMap")
    qgisLayout.addLayoutItem(overviewMap)

    # Template items filled below. Other template frames (planart,
    # bereichsbezeichnung, datum, legend frame, ...) are left as defined in
    # the template.
    regionFrameText = qgisLayout.itemById("regionFrameText")
    parentFrameText = qgisLayout.itemById("parentFrameText")
    bauabschnittFrameText = qgisLayout.itemById("bauabschnittFrameText")
    companyInfo = qgisLayout.itemById("Anschrift")  # imprint / address text
    revisionTable = qgisLayout.itemById("revisionTable").multiFrame()
    companyLogo = qgisLayout.itemById("companyLogo")
    legend = qgisLayout.itemById("legend")
    northArrow = qgisLayout.itemById("NorthArrow")

    revisionTable.setVectorLayer(revisionLayer)
    companyLogo.setPicturePath(appConfig.companyLogoPath)
    legend.setPicturePath(appConfig.getLegendPath(styling, appConfig.pluginDirPath))

    # Missing company info is not fatal: report it and leave the text empty.
    companyInfoResult = appConfig.getCompanyInfoText()
    if companyInfoResult.isFailure():
        infoMessageAggregator.addMessage(
            companyInfoResult.message,
        )
        companyInfo.setText("")
    else:
        companyInfo.setText(str(companyInfoResult.value))

    northArrow.setPicturePath(appConfig.northArrowPath)

    if composerData is not None:
        regionFrameText.setText(composerData.cluster)
        parentFrameText.setText(composerData.region)
        bauabschnittFrameText.setText(composerData.siteCluster)

    mainMap.setScale(int(printingScale), True)

    # Selected RiMo layers, the atlas grid and WMS layers.
    # NOTE(review): this list is never applied to mainMap (there is no
    # setLayers call). Confirm whether mainMap.setLayers(mainMapLayerSet)
    # was intended.
    mainMapLayerSet = list(getRimoLayer(qgisProject, True))
    mainMapLayerSet.append(atlasGrid)
    mainMapLayerSet.extend(getWmsLayer(qgisProject, True))

    # Zoom the overview map to the grid extent, transformed from EPSG:4326
    # into the configured project CRS.
    box = atlasGrid.extent()
    source_crs = QgsCoordinateReferenceSystem.fromEpsgId(4326)
    dest_crs = QgsCoordinateReferenceSystem.fromEpsgId(appConfig.getEPSG())
    transform = QgsCoordinateTransform(
        source_crs, dest_crs, qgisProject.transformContext()
    )
    overviewMap.zoomToExtent(transform.transformBoundingBox(box))

    # The overview map shows only trenches (if present) and the grid.
    overviewMapLayerSet = []
    trenchesLayerResult = getLayerByKeyword("Trenches", qgisProject)
    if trenchesLayerResult.isSuccess():
        overviewMapLayerSet.append(trenchesLayerResult.value)
    overviewMapLayerSet.append(atlasGrid)
    overviewMap.setLayers(overviewMapLayerSet)
    overviewMap.setKeepLayerSet(True)

    overviewMap.overview().setLinkedMap(mainMap)
    qgisInterface.openLayoutDesigner(qgisLayout)

    if createAtlas in ("Atlas", "Atlas_StaticLegend"):
        myAtlas = qgisLayout.atlas()
        myAtlas.setEnabled(True)

        overviewMap.setFrameEnabled(True)

        # Atlas iterates only grid cells flagged as containing data.
        myAtlas.setCoverageLayer(atlasGrid)
        myAtlas.beginRender()
        myAtlas.setFilterExpression('"data_in_g"  = 1')
        myAtlas.setFilterFeatures(True)
        myAtlas.setPageNameExpression("Page")
        myAtlas.setSortExpression("ID")
        myAtlas.setSortFeatures(True)
        # Toggling the atlas off and on is kept from the original code.
        # Presumably it makes the new settings take effect; verify before
        # removing.
        myAtlas.setEnabled(False)
        myAtlas.setEnabled(True)
        # Output file names: '<region>_<cluster>_<siteCluster>_' || "Page"
        myAtlas.setFilenameExpression(
            "'{}_{}_{}".format(
                projectState.regionName,
                projectState.clusterName,
                projectState.siteClusterName,
            )
            + '_\'||"Page"'
        )
        mainMap.setAtlasDriven(True)
        mainMap.setScale(int(printingScale), True)
        designer = qgisInterface.openLayoutDesigner(qgisLayout)
        if myAtlas.coverageLayer() is not None:
            designer.setAtlasPreviewEnabled(True)

    if createAtlas != "Atlas":
        # Individual prints (and the static-legend atlas) start from the
        # current map canvas extent.
        mainMap.zoomToExtent(qgisInterface.mapCanvas().extent())
        mainMap.setScale(int(printingScale), True)

    # TODO: a development watermark was once sketched here with the QGIS 2
    # QgsComposerLabel API. Reimplement with QgsLayoutItemLabel if needed.

    qgisLayout.refresh()


def createAtlasGrid(
    clusterPolygonLayer,
    currentGroup,
    gridLayerValues,
    qgisProject,
    appConfig: ApplicationConfig,
    projectState: ProjectState,
):
    """Build the atlas coverage grid for a cluster polygon layer.

    Each entry of *gridLayerValues* is processed in turn:

    1. A rectangular grid with cell size ``entry["value"]`` is created over
       the cluster extent in the configured EPSG.
    2. The grid is reprojected to EPSG:4326.
    3. Cells intersecting the Trenches or Buildings layer get
       ``data_in_g = 1``.
    4. Cells are numbered (``printID``) and named (``Page`` =
       ``"<row>--<column>"``).
    5. The grid is written to a shapefile, loaded, styled and added to
       *currentGroup*.

    Args:
        clusterPolygonLayer: Polygon layer whose extent the grid covers.
        currentGroup: Layer tree group that receives the grid layer.
        gridLayerValues: Grid definitions as returned by getGridLayerValues.
        qgisProject: The current QGIS project.
        appConfig: Supplies EPSG, shapefile and style paths.
        projectState: Supplies project name and current styling.

    Returns:
        ``Result.ok(gridName)`` with the name of the last grid written, or a
        failed ``Result`` if there is no grid definition, a required layer
        is missing, or the shapefile can't be replaced or written.
    """
    if not gridLayerValues:
        return Result.fail(
            "No grid definition available. The printing grid can't be calculated!"
        )

    # The intersection layers don't change per grid entry; check them once,
    # before any processing work.
    trenchesResult = getLayerByKeyword("Trenches", qgisProject)
    buildingsResult = getLayerByKeyword("Buildings", qgisProject)
    if trenchesResult.isFailure() or buildingsResult.isFailure():
        return Result.fail(
            "Layer {} or {} doesn't exist in layer Tree. The printing grid can't be calculated. Atlas printing process is stopped!".format(
                repr("Trenches"), repr("Buildings")
            ),
        )
    trenchesLayer = trenchesResult.value
    buildingsLayer = buildingsResult.value

    epsgCode = f"EPSG:{appConfig.getEPSG()}"
    xMin, yMin, xMax, yMax = _atlasGridSourceExtent(clusterPolygonLayer, epsgCode)

    gridName = None
    for entry in gridLayerValues:
        spacing = entry["value"]

        # Grow the extent symmetrically until it spans at least one cell.
        # NOTE: the grown extent carries over to subsequent entries.
        if (xMax - xMin) < spacing:
            add = (spacing - (xMax - xMin)) / 2
            xMax = xMax + add
            xMin = xMin - add
        if (yMax - yMin) < spacing:
            add = (spacing - (yMax - yMin)) / 2
            yMax = yMax + add
            yMin = yMin - add

        # Pad each side by half of (size % 10).
        # NOTE(review): this adds `size % 10` to the size instead of rounding
        # it up to a multiple of 10. Confirm the intended padding.
        padEW = ((xMax - xMin) % 10) / 2
        padNS = ((yMax - yMin) % 10) / 2
        nXMin = xMin - padEW
        nYMin = yMin - padNS
        nXMax = xMax + padEW
        nYMax = yMax + padNS

        # Cells per column, used to derive the row/column page names.
        rowsPerColumn = int((nYMax - nYMin) / spacing) + 1

        extent = f"{nXMin}, {nXMax}, {nYMin}, {nYMax}"
        gridWGS = _createWgs84Grid(extent, spacing, epsgCode)

        gridName = "{}_{}_Grid_{}m".format(entry["format"], entry["scale"], spacing)
        gridWGS.setName(gridName)

        _flagGridCellsWithData(gridWGS, trenchesLayer, buildingsLayer)
        _numberGridCells(gridWGS, rowsPerColumn)

        saveResult = _saveGridShapefile(
            gridWGS, gridName, currentGroup, qgisProject, appConfig, projectState
        )
        if saveResult.isFailure():
            return saveResult

    return Result.ok(gridName)


def _atlasGridSourceExtent(layer, epsgCode):
    """Return (xmin, ymin, xmax, ymax) of *layer*'s extent polygon in *epsgCode*."""
    extentResult = processing.run(
        "qgis:polygonfromlayerextent", {"INPUT": layer, "OUTPUT": "memory:lrext"}
    )
    reprojectResult = processing.run(
        "qgis:reprojectlayer",
        {
            "INPUT": extentResult["OUTPUT"],
            "TARGET_CRS": epsgCode,
            "OUTPUT": "memory:repro",
        },
    )
    ext = reprojectResult["OUTPUT"].extent()
    return ext.xMinimum(), ext.yMinimum(), ext.xMaximum(), ext.yMaximum()


def _createWgs84Grid(extent, spacing, epsgCode):
    """Create a rectangle grid over *extent* in *epsgCode*, reprojected to EPSG:4326."""
    gridResult = processing.run(
        "qgis:creategrid",
        {
            "TYPE": 2,
            "EXTENT": extent,
            "HSPACING": spacing,
            "VSPACING": spacing,
            "CRS": epsgCode,
            "OUTPUT": "memory:tempgrid",
        },
    )
    reprojectResult = processing.run(
        "qgis:reprojectlayer",
        {
            "INPUT": gridResult["OUTPUT"],
            "TARGET_CRS": "EPSG:4326",
            "OUTPUT": "memory:reproject",
        },
    )
    return reprojectResult["OUTPUT"]


def _flagGridCellsWithData(gridLayer, trenchesLayer, buildingsLayer):
    """Add ``data_in_g`` and set it to 1 on cells intersecting trenches or buildings."""
    # METHOD 0 starts a new selection and METHOD 1 adds to it, so the final
    # selection is the union of both intersections.
    for intersectLayer, selectionMethod in ((trenchesLayer, 0), (buildingsLayer, 1)):
        processing.run(
            "qgis:selectbylocation",
            {
                "INPUT": gridLayer,
                "PREDICATE": 0,
                "INTERSECT": intersectLayer,
                "METHOD": selectionMethod,
                "OUTPUT": "memory:tempgridintersect",
            },
        )

    colName = "data_in_g"
    provider = gridLayer.dataProvider()
    provider.addAttributes([QgsField(colName, QMetaType.Int)])
    gridLayer.updateFields()
    colIdx = gridLayer.fields().indexFromName(colName)
    gridLayer.setDefaultValueDefinition(colIdx, QgsDefaultValue("0"))

    provider.changeAttributeValues(
        {feature.id(): {colIdx: 1} for feature in gridLayer.getSelectedFeatures()}
    )
    gridLayer.removeSelection()


def _numberGridCells(gridLayer, rowsPerColumn):
    """Add ``printID`` (running index) and ``Page`` ("<row>--<column>") to all cells.

    Cells are numbered in feature iteration order, and a new column starts
    every *rowsPerColumn* cells.
    NOTE(review): this assumes the grid's features are ordered column by
    column with *rowsPerColumn* rows each. Confirm.
    """
    gridLayer.dataProvider().addAttributes(
        [QgsField("printID", QMetaType.Int), QgsField("Page", QMetaType.QString)]
    )
    gridLayer.updateFields()
    fields = gridLayer.fields()
    idIdx = fields.indexOf("printID")
    pageIdx = fields.indexOf("Page")

    # Snapshot ids first so the edit buffer isn't changed mid-iteration.
    featureIds = [feature.id() for feature in gridLayer.getFeatures()]
    gridLayer.startEditing()
    for count, featureId in enumerate(featureIds):
        column, row = divmod(count, rowsPerColumn)
        gridLayer.changeAttributeValue(featureId, idIdx, count)
        gridLayer.changeAttributeValue(featureId, pageIdx, f"{row}--{column}")
    gridLayer.commitChanges()


def _saveGridShapefile(gridLayer, gridName, currentGroup, qgisProject, appConfig, projectState):
    """Replace the grid's shapefile, load it, style it and add it to *currentGroup*.

    Returns:
        ``Result.ok(layer)`` for the loaded layer, or a failed ``Result``.
    """
    shpPath = appConfig.getShapeFilePath(gridName, projectState.getProjectName())

    # Removal may fail while the files are still locked. Report it to the
    # user instead of crashing.
    try:
        removeShapefileIfExists(gridName, qgisProject)
    except Exception as e:
        return Result.fail(
            f"Failed to remove shape file due to: {e}\nPlease wait a few seconds and try again."
        )

    # writeAsVectorFormatV3 requires QGIS >= 3.20.
    save_options = QgsVectorFileWriter.SaveVectorOptions()
    save_options.driverName = "ESRI Shapefile"
    save_options.fileEncoding = "UTF-8"
    writeResult = QgsVectorFileWriter.writeAsVectorFormatV3(
        gridLayer, shpPath, qgisProject.transformContext(), save_options
    )

    layer = QgsVectorLayer(shpPath + ".shp", gridName, "ogr")
    if not layer.isValid():
        # writeResult is (error, errorMessage, newFilename, newLayer).
        return Result.fail(
            f"Grid shapefile '{shpPath}.shp' could not be written or loaded: {writeResult[1]}"
        )

    LayerMetadataHelper.setMetadata(layer, LayerType.GRID_LAYER, gridName)
    qgisProject.addMapLayer(layer, False)
    setStyle(
        layer,
        layer.name(),
        os.path.join(
            appConfig.getStyleFolderPath(projectState), projectState.currentStyling
        ),
    )
    currentGroup.addLayer(layer)
    return Result.ok(layer)


def getGridLayerValues(paperFormat, cScale):
    """Return the atlas grid definition for a paper format and map scale.

    Args:
        paperFormat: Paper size name, "A0" through "A4".
        cScale: Scale denominator, "500" or "1000". An int is accepted too
            and is normalized to its string form.

    Returns:
        A new single-element list holding one dict with the keys:

        - "value": grid cell size, used as the creategrid spacing.
        - "selected": flag.
        - "format": the paper format.
        - "scale": the scale as a string.

    Raises:
        ValueError: if there is no grid definition for the combination.
    """
    # (format, scale) -> (cell size, selected flag)
    gridDefinitions = {
        ("A0", "500"): (350, False),
        ("A0", "1000"): (700, False),
        ("A1", "500"): (250, False),
        ("A1", "1000"): (500, False),
        ("A2", "500"): (180, False),
        ("A2", "1000"): (370, False),
        ("A3", "500"): (130, False),
        ("A3", "1000"): (250, True),
        ("A4", "500"): (100, False),
        # NOTE(review): same cell size as A4/500 although the scale doubles.
        # Confirm.
        ("A4", "1000"): (100, False),
    }

    scale = str(cScale)
    try:
        value, selected = gridDefinitions[(paperFormat, scale)]
    except KeyError:
        raise ValueError(
            f"No grid definition for paper format {paperFormat!r} and scale {cScale!r}"
        ) from None

    # A fresh dict per call, so callers may mutate the result safely.
    return [{"value": value, "selected": selected, "format": paperFormat, "scale": scale}]


def upsertComposerTemplatesInLayoutManager(templatePaths: List[str], qgisProject):
    """Load .qpt templates as print layouts, replacing layouts of the same name.

    The layout name comes from the template. If it is blank, the template's
    file name without extension is used instead.

    Args:
        templatePaths: Paths of the layout template files to load.
        qgisProject: Project whose layout manager receives the layouts.
    """
    layoutManager = qgisProject.layoutManager()

    for templatePath in templatePaths:
        # `with` closes the file even if reading fails.
        with open(templatePath, "r") as templateFile:
            templateContent = templateFile.read()

        domDocument = QDomDocument()
        domDocument.setContent(templateContent, False)

        layout = QgsPrintLayout(qgisProject)
        layout.loadFromTemplate(domDocument, QgsReadWriteContext(), True)

        if not layout.name().strip():
            layout.setName(os.path.splitext(os.path.basename(templatePath))[0])

        # Upsert: drop any existing layout of that name first (the layout
        # manager does not accept duplicate names).
        existingLayout = layoutManager.layoutByName(layout.name())
        if existingLayout is not None:
            layoutManager.removeLayout(existingLayout)

        layoutManager.addLayout(layout)


def removeShapefileIfExists(gridName, qgisProject):
    """Remove the project layer matching *gridName* and delete its shapefile files.

    The layer is looked up with ``getLayerByKeyword``. If none matches,
    nothing happens. Otherwise the layer is removed from *qgisProject*. Then
    every file in that layer's directory named ``<gridName>`` or
    ``<gridName>.<anything>`` is deleted; this covers .shp, .shx, .dbf,
    .prj, .cpg, .shp.xml, ...

    Raises:
        OSError: if the directory can't be listed or a file can't be deleted
            (e.g. still locked). Callers are expected to handle it.
    """
    layerResult = getLayerByKeyword(gridName, qgisProject)
    if not layerResult.isSuccess():
        return

    gridLayer = layerResult.value
    # Read the source before removal: removeMapLayer may delete the layer
    # object. Use the matched layer's own directory, not that of some other
    # project layer.
    shpDir = os.path.dirname(gridLayer.source())
    qgisProject.removeMapLayer(gridLayer)

    # Match the exact base name so e.g. "<gridName>_backup.shp" is kept.
    sidecarPrefix = gridName + "."
    for filename in os.listdir(shpDir):
        if filename == gridName or filename.startswith(sidecarPrefix):
            os.remove(os.path.join(shpDir, filename))


def checkIfLayoutExists(composerName, qgisProject):
    """Tell whether the project has a print layout named ``<composerName>.qpt``.

    Args:
        composerName: Layout name without the ``.qpt`` suffix.
        qgisProject: Project whose layout manager is searched.

    Returns:
        True if some print layout's name equals ``composerName + ".qpt"``,
        otherwise False.

    NOTE(review): createNewRimoComposer names its layout ``composerName``
    without the ``.qpt`` suffix, so layouts created there are not detected
    by this check. Confirm whether the suffix is intended.
    """
    expectedName = f"{composerName}.qpt"
    return any(
        layout.name() == expectedName
        for layout in qgisProject.layoutManager().printLayouts()
    )
