import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torchvision import models


######################## init weight ##############################################
def weights_init_kaiming(m):
    """Kaiming-style initialiser meant to be passed to ``Module.apply``.

    The module is dispatched on its class *name* (substring match):

    - name contains ``'Conv'``: weight <- Kaiming normal, ``mode='fan_in'``.
    - name contains ``'Linear'``: weight <- Kaiming normal, ``mode='fan_out'``;
      bias <- 0.
    - name contains ``'BatchNorm1d'``: weight <- N(1.0, 0.02); bias <- 0.

    Any other module is left untouched.  A missing parameter (for example
    ``Linear(bias=False)``, ``BatchNorm1d(affine=False)``, or a container
    whose class name merely contains ``'Conv'``) is skipped instead of
    raising ``AttributeError``.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    # Parameters may legitimately be absent (None) or not exist at all on
    # wrapper modules, so fetch them defensively once.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if classname.find('Conv') != -1:
        if has_weight:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        if has_weight:
            init.kaiming_normal_(weight.data, a=0, mode='fan_out')
        if has_bias:
            init.constant_(bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        if has_weight:
            init.normal_(weight.data, 1.0, 0.02)
        if has_bias:
            init.constant_(bias.data, 0.0)


def weights_init_classifier(m):
    """Initialiser for classifier layers, meant for ``Module.apply``.

    Modules whose class name contains ``'Linear'`` get their weight drawn
    from N(0, 0.001) and their bias set to zero.  Every other module is left
    untouched.  A layer created without a bias (``bias=False``) only has its
    weight initialised instead of raising ``AttributeError``.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') == -1:
        return

    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if isinstance(weight, torch.Tensor):
        init.normal_(weight.data, std=0.001)
    if isinstance(bias, torch.Tensor):
        init.constant_(bias.data, 0.0)


####################### l2GL ##################################################
def intradis(x, picnum):
    """Mean squared difference over all unordered pairs of rows of ``x``.

    Args:
        x: tensor whose first dimension indexes the samples; only the first
            ``picnum`` samples are used.
        picnum: number of samples to pair up.

    Returns:
        Tensor with a leading dimension of 1 holding, element-wise, the
        average of ``(x[i] - x[j]) ** 2`` over every pair ``i < j``.
    """
    rows = torch.split(x, 1, 0)
    gaps = (
        rows[first] - rows[second]
        for first in range(picnum)
        for second in range(first + 1, picnum)
    )
    squared = [torch.mul(gap, gap) for gap in gaps]
    return torch.mean(torch.cat(squared, 0), 0, keepdim=True)


def l2loss(x, picnum):
    """Pairwise squared-difference statistics within and between groups.

    ``x`` is cut along dimension 0 into consecutive groups of ``picnum``
    samples (presumably the images of one identity -- confirm with caller).

    Args:
        x: tensor whose first dimension indexes the samples.
        picnum: number of consecutive samples forming one group.

    Returns:
        A tensor concatenated along dimension 0.  When ``picnum == 1`` it
        holds one entry per pair of groups ``k < m``: the squared
        element-wise difference of the two groups averaged over dimension 0.
        Otherwise it holds one ``intradis`` entry per group first, followed
        by those same between-group entries.
    """
    groups = torch.split(x, picnum, 0)
    count = len(groups)

    def _between(a, b):
        # Squared gap between two groups, averaged over their samples.
        gap = groups[a] - groups[b]
        return torch.mean(torch.mul(gap, gap), 0, keepdim=True)

    inter = [_between(a, b) for a in range(count) for b in range(a + 1, count)]
    if picnum == 1:
        return torch.cat(inter, 0)

    intra = [intradis(group, picnum) for group in groups]
    return torch.cat((torch.cat(intra, 0), torch.cat(inter, 0)), 0)


####################### VGL ###################################################
def intravgl(x, y, picnum):
    """Mean squared deviation of the rows of ``x`` from the reference ``y``.

    Args:
        x: tensor whose first dimension indexes the samples; only the first
            ``picnum`` samples are used.
        y: reference tensor subtracted from every single-sample slice of
            ``x`` (the caller in this file passes the group mean).
        picnum: number of samples to include.

    Returns:
        Tensor with a leading dimension of 1 holding, element-wise, the
        average of ``(x[i] - y) ** 2`` over the ``picnum`` samples.
    """
    rows = torch.split(x, 1, 0)
    offsets = (rows[idx] - y for idx in range(picnum))
    squared = [torch.mul(offset, offset) for offset in offsets]
    return torch.mean(torch.cat(squared, 0), 0, keepdim=True)


def vglloss(x, picnum):
    """Squared deviations from group centres and from the overall centre.

    ``x`` is cut along dimension 0 into consecutive groups of ``picnum``
    samples.

    Args:
        x: tensor whose first dimension indexes the samples.
        picnum: number of consecutive samples forming one group.

    Returns:
        A tensor concatenated along dimension 0.  When ``picnum == 1`` it
        holds, per sample, the squared deviation from the mean of all
        samples.  Otherwise it holds one ``intravgl`` entry per group (spread
        of the group around its own mean) followed by, per group, the squared
        deviation of the group mean from the mean of all group means.
    """
    groups = torch.split(x, picnum, 0)

    def _squared(delta):
        return torch.mul(delta, delta)

    if picnum == 1:
        # Every group is a single sample: compare it with the batch mean.
        overall = torch.mean(x, 0, keepdim=True)
        return torch.cat([_squared(sample - overall) for sample in groups], 0)

    centroids = [torch.mean(group, 0, keepdim=True) for group in groups]
    within = [
        intravgl(group, centroid, picnum)
        for group, centroid in zip(groups, centroids)
    ]
    stacked = torch.cat(centroids, 0)
    overall = torch.mean(stacked, 0, keepdim=True)
    # stacked[k] drops the leading dim; broadcasting against `overall`
    # restores a leading dimension of 1 for each entry.
    between = [_squared(stacked[k] - overall) for k in range(len(groups))]
    return torch.cat((torch.cat(within, 0), torch.cat(between, 0)), 0)


#############################  Attention ##################################################
class ConvBlock(nn.Module):
    """2-D convolution followed by batch normalisation and ReLU.

    Args:
    - in_c (int): number of input channels.
    - out_c (int): number of output channels.
    - k (int or tuple): kernel size.
    - s (int or tuple): stride (default 1).
    - p (int or tuple): padding (default 0).
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=k,
            stride=s,
            padding=p,
        )
        self.bn = nn.BatchNorm2d(num_features=out_c)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return F.relu(normalised)


class ThreeDAttn3(nn.Module):
    """Attention module producing a sigmoid mask from a feature map.

    A per-channel gate is computed from the spatially averaged input through
    a 1x1 bottleneck (``conv1`` -> ``conv2``) and multiplied onto the input;
    the gated map then goes through a second 1x1 bottleneck (``conv3`` ->
    ``conv4``) and a sigmoid.

    Args:
    - in_channels (int): channels of the input map; must be divisible by
      ``reduction_rate``.
    - out_channels (int): channels of the returned mask.
    - reduction_rate (int): bottleneck reduction factor (default 16).
    """

    def __init__(self, in_channels, out_channels, reduction_rate=16):
        super(ThreeDAttn3, self).__init__()
        assert in_channels % reduction_rate == 0
        bottleneck = in_channels // reduction_rate

        # Channel-gate branch.
        self.conv1 = ConvBlock(in_channels, bottleneck, 1)
        self.conv2 = ConvBlock(bottleneck, in_channels, 1)

        # Mask branch applied to the gated feature map.
        self.conv3 = ConvBlock(in_channels, bottleneck, 1)
        self.conv4 = ConvBlock(bottleneck, out_channels, 1)

    def forward(self, x):
        """Return the sigmoid mask computed from feature map ``x``."""
        # Pool over the full spatial extent to get one value per channel.
        spatial_size = x.size()[2:]
        gate = self.conv2(self.conv1(F.avg_pool2d(x, spatial_size)))
        gated = x * gate
        mask = self.conv4(self.conv3(gated))
        return torch.sigmoid(mask)


###############################################################################

# |--Linear--|--bn--|--(LeakyReLU)--|--(Dropout)--|--Linear--|
class ClassBlock(nn.Module):
    """Bottleneck embedding followed by a linear classifier.

    ``add_block`` is Linear -> BatchNorm1d, optionally followed by
    LeakyReLU(0.1) and Dropout(0.5); it is initialised with
    ``weights_init_kaiming``.  ``classifier`` is a single Linear layer
    initialised with ``weights_init_classifier``.

    Args:
    - input_dim (int): size of the input feature vector.
    - class_num (int): number of output classes.
    - dropout (bool): append ``Dropout(p=0.5)`` to the bottleneck.
    - relu (bool): append ``LeakyReLU(0.1)`` to the bottleneck.
    - num_bottleneck (int): width of the bottleneck embedding.
    """

    def __init__(self, input_dim, class_num, dropout=True, relu=True, num_bottleneck=512):
        super(ClassBlock, self).__init__()
        bottleneck_layers = [
            nn.Linear(input_dim, num_bottleneck),
            nn.BatchNorm1d(num_bottleneck),
        ]
        if relu:
            bottleneck_layers.append(nn.LeakyReLU(0.1))
        if dropout:
            bottleneck_layers.append(nn.Dropout(p=0.5))
        self.add_block = nn.Sequential(*bottleneck_layers)
        self.add_block.apply(weights_init_kaiming)

        self.classifier = nn.Sequential(nn.Linear(num_bottleneck, class_num))
        self.classifier.apply(weights_init_classifier)

    def forward(self, x):
        """Return class logits for the feature batch ``x``."""
        embedding = self.add_block(x)
        return self.classifier(embedding)


# Part Model proposed in Yifan Sun et al. (2018)
class HUANG_3DTA(nn.Module):
    """ResNet-50 backbone with chained attention masks and 11 + 1 heads.

    Module layout (attribute names are kept exactly as they were because
    they are the keys of saved state dicts):

    - ``classifier0`` .. ``classifier10``: ``ClassBlock`` heads.  Heads 0-1
      take 1024-d vectors, heads 2-10 take 2048-d vectors.
    - ``classifier11`` .. ``classifier15`` and ``classifier17``: despite the
      name these are ``ThreeDAttn3`` attention modules.
    - ``classifier16``: ``ClassBlock`` applied to the pooled ``vglloss``
      statistics.

    Args:
    - class_num (int): number of classes predicted by every head.
    - picnum (int): group size forwarded to ``vglloss``.
    - batch_size (int): unused; kept so existing callers keep working.
    """

    def __init__(self, class_num, picnum, batch_size=30):
        super(HUANG_3DTA, self).__init__()
        self.picnum = picnum
        # Number of classification heads whose outputs are returned in `y`:
        # 5 global heads (classifier0-4) + 6 stripe heads (classifier5-10).
        self.part = 11
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` -- confirm against the installed
        # version before changing.
        model_ft = models.resnet50(pretrained=True)
        self.maxpool = nn.MaxPool2d(kernel_size=6, stride=6, padding=0)  # not used in forward
        self.model = model_ft
        self.avgpool = nn.AdaptiveAvgPool2d((6, 1))   # six horizontal stripes
        self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))  # global descriptor
        self.dropout = nn.Dropout(p=0.5)
        # Remove the final downsample so layer4 keeps layer3's spatial size;
        # the attention masks computed on layer3 output rely on this.
        self.model.layer4[0].downsample[0].stride = (1, 1)
        self.model.layer4[0].conv2.stride = (1, 1)
        # Define the 11 classification heads.
        setattr(self, 'classifier0', ClassBlock(1024, class_num, True, False, 256))
        setattr(self, 'classifier1', ClassBlock(1024, class_num, True, False, 256))
        for i in range(2, self.part):
            name = 'classifier' + str(i)
            setattr(self, name, ClassBlock(2048, class_num, True, False, 256))
        # Attention modules (registered under "classifier*" names).
        setattr(self, 'classifier11', ThreeDAttn3(1024, 1024))
        setattr(self, 'classifier12', ThreeDAttn3(1024, 2048))
        setattr(self, 'classifier13', ThreeDAttn3(2048, 2048))
        setattr(self, 'classifier14', ThreeDAttn3(2048, 2048))
        setattr(self, 'classifier15', ThreeDAttn3(1024, 1024))
        setattr(self, 'classifier17', ThreeDAttn3(1024, 1024))
        # Head applied to the vglloss statistics.
        setattr(self, 'classifier16', ClassBlock(2048, class_num, True, False, 256))

    @staticmethod
    def _flatten_per_sample(t):
        """Collapse every dimension after the first into one.

        Used instead of a bare ``torch.squeeze`` so that a leading dimension
        of size 1 (single-sample batch) is preserved.
        """
        return t.reshape(t.size(0), -1)

    def forward(self, x):
        """Run the network.

        Each attention mask is computed from the output of one residual
        block and multiplied onto the output of the *next* block.

        Returns:
            tuple ``(y, yp, feats)`` where ``y`` is the list of logits from
            ``classifier0`` .. ``classifier10``; ``yp`` has one column per
            sample holding the mean squared activation averaged over the 11
            feature vectors; ``feats`` is the ``classifier16`` output for the
            globally pooled ``vglloss`` statistics of the final feature map.
        """
        part = {}
        predict = {}

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3[0](x)
        x = self.model.layer3[1](x)
        x = self.model.layer3[2](x)

        x1 = self.classifier17(x)
        x = self.model.layer3[3](x)
        x = torch.mul(x, x1)

        x1 = self.classifier15(x)
        x = self.model.layer3[4](x)
        x = torch.mul(x, x1)

        # Head 0: global descriptor after layer3[4].
        x1 = self.classifier11(x)
        part[0] = self._flatten_per_sample(self.avgpool1(x))
        predict[0] = self.classifier0(part[0])

        # Head 1: global descriptor after layer3[5].
        x = self.model.layer3[5](x)
        x = torch.mul(x, x1)
        x1 = self.classifier12(x)
        part[1] = self._flatten_per_sample(self.avgpool1(x))
        predict[1] = self.classifier1(part[1])

        # Head 2: global descriptor after layer4[0].
        x = self.model.layer4[0](x)
        x = torch.mul(x, x1)
        x1 = self.classifier13(x)
        part[2] = self._flatten_per_sample(self.avgpool1(x))
        predict[2] = self.classifier2(part[2])

        # Head 3: global descriptor after layer4[1].
        x = self.model.layer4[1](x)
        x = torch.mul(x, x1)
        x1 = self.classifier14(x)
        part[3] = self._flatten_per_sample(self.avgpool1(x))
        predict[3] = self.classifier3(part[3])

        # Head 4: global descriptor after layer4[2]; the un-pooled map is
        # kept for the vglloss branch below.
        x = self.model.layer4[2](x)
        x = torch.mul(x, x1)
        feats = x
        part[4] = self._flatten_per_sample(self.avgpool1(x))
        predict[4] = self.classifier4(part[4])

        # Heads 5-10: one per horizontal stripe of the pooled map.
        x = self.avgpool(x)
        x = self.dropout(x)
        for i in range(5, 11):
            part[i] = self._flatten_per_sample(x[:, :, i - 5])
            c = getattr(self, 'classifier' + str(i))
            predict[i] = c(part[i])

        y = []
        yp = []
        for i in range(self.part):
            y.append(predict[i])
            # Mean squared activation of this feature vector, per sample.
            yn = torch.mean(torch.mul(part[i], part[i]), 1, keepdim=True)
            yp.append(yn)
        yp = torch.cat(yp, 1)
        yp = torch.mean(yp, 1, keepdim=True)

        feats = vglloss(feats, self.picnum)

        ######## avgpool #########
        feats = self._flatten_per_sample(self.avgpool1(feats))
        ##########################

        feats = self.classifier16(feats)
        return y, yp, feats


class HUANG_3DTA_test(nn.Module):
    """Feature-extraction wrapper around an existing model.

    Reuses the wrapped model's backbone (``model.model``) and its attention
    modules (``classifier11`` .. ``classifier15`` and ``classifier17``) and
    returns the stripe-pooled final feature map; no classification head runs.

    Args:
    - model: object exposing ``model`` (the backbone) and the attention
      modules listed above.
    """

    def __init__(self, model):
        super(HUANG_3DTA_test, self).__init__()
        self.part = 6
        self.model0 = model
        self.model = model.model
        self.avgpool = nn.AdaptiveAvgPool2d((6, 1))
        self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))
        # remove the final downsample
        entry_block = self.model.layer4[0]
        entry_block.downsample[0].stride = (1, 1)
        entry_block.conv2.stride = (1, 1)

    def forward(self, x):
        """Return the (6, 1)-pooled feature map with the last dim dropped."""
        backbone = self.model
        owner = self.model0

        stem = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        for stage in stem:
            x = stage(x)
        for idx in range(3):
            x = backbone.layer3[idx](x)

        # Each mask is computed on the current map and applied to the output
        # of the residual block that follows.
        gated_blocks = (
            (owner.classifier17, backbone.layer3[3]),
            (owner.classifier15, backbone.layer3[4]),
            (owner.classifier11, backbone.layer3[5]),
            (owner.classifier12, backbone.layer4[0]),
            (owner.classifier13, backbone.layer4[1]),
            (owner.classifier14, backbone.layer4[2]),
        )
        for attention, block in gated_blocks:
            mask = attention(x)
            x = torch.mul(block(x), mask)

        pooled = self.avgpool(x)
        return pooled.view(pooled.size(0), pooled.size(1), pooled.size(2))
