

圖神經(jīng)網(wǎng)絡(luò)相似度計(jì)算

Published: 2022-08-20 · Views: 2556

圖神經(jīng)網(wǎng)絡(luò)相似度計(jì)算

Note: if you find these posts helpful, please like and bookmark them. I publish content on artificial intelligence and big data every week, mostly original: Python, Java, Scala, and SQL code; CV, NLP, and recommender systems; Spark, Flink, Kafka, HBase, Hive, Flume, and so on. It is all practical material, including walkthroughs of papers from top conferences. Let's improve together.
Today I would like to share a paper on graph similarity computation with graph neural networks:
SimGNN: A Neural Network Approach to Fast Graph Similarity Computation

 

Introduction



圖神經(jīng)網(wǎng)絡(luò)是當(dāng)下比較火的模型之一,使用神經(jīng)網(wǎng)絡(luò)來學(xué)習(xí)圖結(jié)構(gòu)數(shù)據(jù),提取和發(fā)掘圖結(jié)構(gòu)數(shù)據(jù)中的特征和模式,滿足聚類、分類、預(yù)測、分割、生成等圖學(xué)習(xí)任務(wù)需求的算法。本文是主要通過圖神經(jīng)網(wǎng)絡(luò)來對(duì)兩個(gè)圖的相似性進(jìn)行快速打分的模型。

1. Training Data



This article uses the GEDDataset class provided by PyTorch Geometric (`torch_geometric.datasets`, not a built-in of torch itself), which can be instantiated directly. The dataset contains 700 graphs; each graph has at most 10 nodes, and each node is described by 29 features.

    The code is as follows (example):

 

 def process_dataset(self):
     """
     Downloading and processing dataset.
     """
     print("\nPreparing dataset.\n")

     self.training_graphs = GEDDataset(
         "datasets/{}".format(self.args.dataset), self.args.dataset, train=True
     )
     self.testing_graphs = GEDDataset(
         "datasets/{}".format(self.args.dataset), self.args.dataset, train=False
     )

2. Model Input

Each forward pass takes two graphs as input, including their edge indices and node features.

The code is as follows (example):

 def forward(self, data):
     edge_index_1 = data["g1"].edge_index
     edge_index_2 = data["g2"].edge_index
     features_1 = data["g1"].x
     print(features_1.shape)
     features_2 = data["g2"].x
     print(features_2.shape)
     batch_1 = (
         data["g1"].batch
         if hasattr(data["g1"], "batch")
         else torch.tensor((), dtype=torch.long).new_zeros(data["g1"].num_nodes)
     )
     batch_2 = (
         data["g2"].batch
         if hasattr(data["g2"], "batch")
         else torch.tensor((), dtype=torch.long).new_zeros(data["g2"].num_nodes)
     )
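The `batch` fallback in the code above builds a vector of zeros, meaning "every node belongs to graph 0" when a single graph is passed in. A minimal pure-PyTorch sketch of what a batch vector encodes (the node counts here are made-up toy values):

```python
import torch

# A batch vector maps each node to the index of the graph it belongs to.
# Here: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)

# The single-graph fallback from forward() above: every node maps to graph 0.
num_nodes = 4
single = torch.tensor((), dtype=torch.long).new_zeros(num_nodes)
print(single)                 # tensor([0, 0, 0, 0])

# Counting nodes per graph with bincount recovers the graph sizes.
print(torch.bincount(batch))  # tensor([3, 2])
```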

三、圖神經(jīng)網(wǎng)絡(luò)提取更新每個(gè)點(diǎn)的信息

The histogram-based approach to building features used here is fairly novel.

    def convolutional_pass(self, edge_index, features):
        """
        Making convolutional pass.
        :param edge_index: Edge indices.
        :param features: Feature matrix.
        :return features: Abstract feature matrix.
        """
        features = self.convolution_1(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_2(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_3(features, edge_index)
        return features
# each node is passed through three GCN layers
abstract_features_1 = self.convolutional_pass(edge_index_1, features_1)
print(abstract_features_1.shape)
abstract_features_2 = self.convolutional_pass(edge_index_2, features_2)
print(abstract_features_2.shape)
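Each `GCNConv` layer above aggregates a node's neighbors before the linear transform. What one such pass computes can be sketched in pure PyTorch using the standard GCN propagation rule; the graph, feature sizes, and weights below are toy values for illustration (29 input features, matching the dataset):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy undirected graph: 4 nodes, edges 0-1, 1-2, 2-3.
num_nodes, in_dim, out_dim = 4, 29, 8
A = torch.zeros(num_nodes, num_nodes)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

# GCN propagation rule: H' = D^-1/2 (A + I) D^-1/2 H W
A_hat = A + torch.eye(num_nodes)          # add self-loops
deg = A_hat.sum(dim=1)
D_inv_sqrt = torch.diag(deg.pow(-0.5))
norm = D_inv_sqrt @ A_hat @ D_inv_sqrt    # symmetric normalization

X = torch.randn(num_nodes, in_dim)        # node features
W = torch.randn(in_dim, out_dim)          # layer weights
H = F.relu(norm @ X @ W)                  # one convolutional pass

print(H.shape)  # torch.Size([4, 8])
```

After three such passes (with dropout in between, as in `convolutional_pass`), each row of `H` is an abstract embedding of one node.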

4. Computing Node-to-Node Relations to Obtain Histogram Features

    def calculate_histogram(
        self, abstract_features_1, abstract_features_2, batch_1, batch_2
    ):
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)
        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()

        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))

        scores = torch.matmul(
            abstract_features_1, abstract_features_2.permute([0, 2, 1])
        ).detach()
        hist_list = []
        for i, mat in enumerate(scores):
            mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)
            hist = torch.histc(mat, bins=self.args.bins)
            hist = hist / torch.sum(hist)
            hist = hist.view(1, -1)
            hist_list.append(hist)
        print(torch.stack(hist_list).view(-1, self.args.bins).shape)
        return torch.stack(hist_list).view(-1, self.args.bins)
if self.args.histogram:
    hist = self.calculate_histogram(
        abstract_features_1, abstract_features_2, batch_1, batch_2
    )
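The histogram step can be illustrated in isolation. The sketch below uses toy node embeddings for a pair of graphs and assumes `bins=16`; it also passes explicit `min`/`max` bounds to `torch.histc` (a choice, since sigmoid outputs lie in (0, 1)), where the model code relies on the defaults:

```python
import torch

torch.manual_seed(0)
bins = 16                                  # stands in for self.args.bins

# Toy node embeddings for two graphs (5 and 4 nodes, 8-dim each).
emb_1 = torch.randn(5, 8)
emb_2 = torch.randn(4, 8)

# Pairwise node-node interaction scores, squashed into (0, 1).
scores = torch.sigmoid(emb_1 @ emb_2.t()).view(-1)

# Histogram over the scores, normalized to sum to 1.
hist = torch.histc(scores, bins=bins, min=0, max=1)
hist = hist / hist.sum()

print(hist.shape)  # torch.Size([16])
```

The normalized histogram is a fixed-length summary of all node-to-node similarities, so it can be concatenated with the graph-level features regardless of graph size.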

 

5. Attention Layer to Obtain Graph-Level Features

    def forward(self, x, batch, size=None):
        size = batch[-1].item() + 1 if size is None else size
        mean = scatter_mean(x, batch, dim=0, dim_size=size)
        transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))
        coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))
        weighted = coefs.unsqueeze(-1) * x
       
        return scatter_add(weighted, batch, dim=0, dim_size=size)
       
pooled_features_1 = self.attention(abstract_features_1, batch_1)
pooled_features_2 = self.attention(abstract_features_2, batch_2)
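The `scatter_mean`/`scatter_add` calls above come from the torch_scatter package. For a single graph the same global-context attention can be sketched in pure PyTorch; the embedding size and weights below are toy values:

```python
import torch

torch.manual_seed(0)

dim = 8
x = torch.randn(5, dim)                  # node embeddings for one graph
weight_matrix = torch.randn(dim, dim)    # learnable in the real model

# Global context: tanh of the mean node embedding times the weight matrix.
context = torch.tanh(x.mean(dim=0) @ weight_matrix)

# Per-node attention coefficient: sigmoid of similarity to the context.
coefs = torch.sigmoid(x @ context)       # shape (5,)

# Graph embedding: attention-weighted sum of node embeddings.
pooled = (coefs.unsqueeze(-1) * x).sum(dim=0)

print(pooled.shape)  # torch.Size([8])
```

Nodes similar to the graph's "average" direction get coefficients near 1 and dominate the pooled embedding; outlier nodes are down-weighted.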

6. Computing Graph-to-Graph Relations with an NTN

def forward(self, embedding_1, embedding_2):
    batch_size = len(embedding_1)
    scoring = torch.matmul(
        embedding_1, self.weight_matrix.view(self.args.filters_3, -1)
    )
    scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1])  # filters_3 can be read as the number of relation types to learn
    scoring = torch.matmul(
        scoring, embedding_2.view(batch_size, self.args.filters_3, 1)
    ).view(batch_size, -1)
    combined_representation = torch.cat((embedding_1, embedding_2), 1)
    block_scoring = torch.t(
        torch.mm(self.weight_matrix_block, torch.t(combined_representation))
    )
    scores = F.relu(scoring + block_scoring + self.bias.view(-1))
    return scores
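The reshaping in the forward pass above implements a bilinear tensor product. The shapes are easier to follow in a pure-PyTorch sketch with the tensor kept in 3-D form; the dimensions (`d` for the embedding size, `K` for the number of relation slices) and all weights below are toy assumptions:

```python
import torch

torch.manual_seed(0)

d, K = 16, 16
e1 = torch.randn(1, d)       # graph-1 embedding (batch of 1)
e2 = torch.randn(1, d)       # graph-2 embedding

W = torch.randn(d, d, K)     # bilinear tensor: one d x d slice per relation
V = torch.randn(K, 2 * d)    # block weights for the concatenated embeddings
b = torch.randn(K)           # bias

# Bilinear term: score_k = e1 @ W[:, :, k] @ e2^T for each of the K slices.
bilinear = torch.stack([e1 @ W[:, :, k] @ e2.t() for k in range(K)]).view(1, K)

# Linear block term on the concatenated pair, plus bias and ReLU.
block = torch.cat((e1, e2), dim=1) @ V.t()
scores = torch.relu(bilinear + block + b)

print(scores.shape)  # torch.Size([1, 16])
```

Each of the K output entries measures one learned "relation" between the two graph embeddings; the model's reshape-and-matmul version computes the same thing without the Python loop.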

7. Training the Model and Obtaining Predictions

 def process_batch(self, data):
    self.optimizer.zero_grad()
    data = self.transform(data)
    target = data["target"]
    prediction = self.model(data)
    loss = F.mse_loss(prediction, target, reduction="sum")
    loss.backward()
    self.optimizer.step()
    return loss.item()

Summary

This article combines node-to-node comparisons with graph-to-graph comparisons to compute a similarity score between two graphs, drawing on GCNs, an NTN, attention, and histogram features along the way. The approach is quite creative.