'''
An enhanced version of LS-Reader that lets users build custom features on top of the LS-Reader core.
'''

from lsreader import D3plotReader, DataType as dt, D3P_Parameter
import numpy as np

class D3plotReaderPro(D3plotReader):
    '''
    Definition of D3plotReaderPro that inherits from D3plotReader.

    Adds ``solid_nodal_average_*`` methods. Each one reads a single
    per-solid-element result and returns a dict that maps node ID to the
    arithmetic mean of that result over the solid elements connected to the
    node.
    '''

    def __init__(self, path):
        '''
        Open the database at ``path`` through D3plotReader.
        input: path --- path forwarded unchanged to D3plotReader.__init__
        '''
        super().__init__(path)

    def _solid_node_elements(self, param):
        '''
        Build a map from 0-based node index to the solid elements using that node.
        input: param --- parameter forwarded to get_data
        return: dict {node index (0-based): [solid element row indices]}
        '''
        solids = super().get_data(dt.D3P_SOLID_ID_CONNECTIVITY_MAT, parameter=param)
        # The first two columns are skipped. The remaining columns are treated
        # as 1-based node indices (hence the "- 1" below).
        # NOTE(review): column layout taken from the original code; confirm
        # against the lsreader documentation for D3P_SOLID_ID_CONNECTIVITY_MAT.
        solid_conn = solids[:, 2:]
        node_elements = {}
        for ele, solid in enumerate(solid_conn):
            # A row may repeat a node number, e.g. degenerate solids stored in
            # 8-node form. dict.fromkeys drops the repeats (keeping first-seen
            # order), so each element is counted only once per node and is not
            # over-weighted in the average.
            for node in dict.fromkeys(solid):
                if node <= 0:
                    # A non-positive index would become a negative key and
                    # silently wrap around to the end of node_ids.
                    continue
                node_elements.setdefault(node - 1, []).append(ele)
        return node_elements

    def _solid_nodal_average(self, data_type, param):
        '''
        Average a per-solid-element scalar result onto the nodes.
        input: data_type --- lsreader DataType of the per-solid-element result
               param     --- parameter forwarded to get_data
        return: dict {node ID: mean of the result over the solids sharing that node}
        '''
        node_elements = self._solid_node_elements(param)
        values = super().get_data(data_type, parameter=param)
        node_ids = super().get_data(dt.D3P_NODE_IDS, parameter=param)
        nodal_average = {}
        for key, elements in node_elements.items():
            # Every node in node_elements has at least one element, so the
            # division is safe.
            nodal_average[node_ids[key]] = sum(values[ele] for ele in elements) / len(elements)
        return nodal_average

    def solid_nodal_average_signed_von_mises_stress(self, param):
        '''
        Translate solid signed von mises stress data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged signed von Mises stress}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_SIGNED_VON_MISES_STRESS, param)

    def solid_nodal_average_signed_von_mises_strain(self, param):
        '''
        Translate solid signed von mises strain data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged signed von Mises strain}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_SIGNED_VON_MISES_STRAIN, param)

    def solid_nodal_average_1st_principal_stress(self, param):
        '''
        Translate solid p1 stress data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged 1st principal stress}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_1ST_PRINCIPAL_STRESS, param)

    def solid_nodal_average_2nd_principal_stress(self, param):
        '''
        Translate solid p2 stress data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged 2nd principal stress}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_2ND_PRINCIPAL_STRESS, param)

    def solid_nodal_average_3rd_principal_stress(self, param):
        '''
        Translate solid p3 stress data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged 3rd principal stress}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_3RD_PRINCIPAL_STRESS, param)

    def solid_nodal_average_max_principal_strain(self, param):
        '''
        Translate solid p1 (max principal) strain data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged max principal strain}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_MAX_PRINCIPAL_STRAIN, param)

    def solid_nodal_average_2nd_principal_strain(self, param):
        '''
        Translate solid p2 strain data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged 2nd principal strain}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_2ND_PRINCIPAL_STRAIN, param)

    def solid_nodal_average_min_principal_strain(self, param):
        '''
        Translate solid p3 (min principal) strain data to nodal average
        input: param --- parameter
        return: dict {node ID: nodal-averaged min principal strain}
        '''
        return self._solid_nodal_average(dt.D3P_SOLID_MIN_PRINCIPAL_STRAIN, param)