📝 convert nodes_id to edges_id and convert back #102

Open
Yonv1943 opened this issue Apr 9, 2023 · 10 comments
Labels
discussion code understanding

Comments

@Yonv1943
Collaborator

Yonv1943 commented Apr 9, 2023

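The simulation environment needs one feature:

Take a list that stores the contraction order as pairs of nodes and convert it into a list that stores the contraction order of the edges joining those pairs.

See the following code in TNCO_env.py:

First create an instance of the simulation environment class and choose the circuit to convert.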

def unit_test_convert_node2s_to_edge_sorts():
    gpu_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    device = th.device(f'cuda:{gpu_id}' if th.cuda.is_available() and gpu_id >= 0 else 'cpu')

    nodes_list, ban_edges = NodesSycamoreN12M14, 0
    # nodes_list, ban_edges = NodesSycamoreN14M14, 0
    # nodes_list, ban_edges = NodesSycamoreN53M12, 0
    # nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=8), 8
    # nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=100), 100
    # nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=2000), 2000
    # from TNCO_env import get_nodes_list_of_tensor_tree
    # nodes_list, ban_edges = get_nodes_list_of_tensor_tree(depth=3), 2 ** (3 - 1)

    env = TensorNetworkEnv(nodes_list=nodes_list, ban_edges=ban_edges, device=device)
    print(f"\nnum_nodes      {env.num_nodes:9}"
          f"\nnum_edges      {env.num_edges:9}"
          f"\nban_edges      {env.ban_edges:9}")

The following demonstrates converting edge_ary to node2s and back to an edge order, using two function calls:

  • edge_ary → edge_sort → node2s: node2s = env.convert_edge_sort_to_node2s(edge_sort=edge_ary.argsort(dim=0))
  • node2s → edge_sort: edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)

    num_envs = 6

    # th.save(edge_arys, 'temp.pth')
    # edge_arys = th.load('temp.pth', map_location=device)

    edge_arys = th.rand((num_envs, env.num_edges - env.ban_edges), device=device)
    edge_ary = edge_arys[0]
    print(edge_ary.argsort().shape)
    print(edge_ary.argsort())
    node2s = env.convert_edge_sort_to_node2s(edge_sort=edge_ary.argsort(dim=0))
    edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)
    print(edge_sort.shape)
    print(edge_sort)

    print(edge_sort - edge_ary.argsort())
    edge_sorts = edge_sort.unsqueeze(0)
    multiple_times = env.get_log10_multiple_times(edge_sorts=edge_sorts)
    print(f"multiple_times(log10) {multiple_times.numpy()}")

The output (for the circuit nodes_list, ban_edges = NodesSycamoreN12M14, 0) is:

num_nodes             51
num_edges             99
ban_edges              0
torch.Size([99])
tensor([30, 36, 49, 35, 55, 65,  0, 28, 61, 52, 45, 69, 10, 21, 83, 18, 56,  9,
        14, 70, 39, 19, 74, 43, 68, 75, 60, 81, 29, 47, 94, 24, 58, 77, 64, 15,
        13, 72, 87, 32, 71, 51, 85,  6, 44, 34, 96, 40, 38, 97, 46, 53, 82, 84,
        22, 90, 25, 23, 33, 92,  1, 62, 42, 91, 67, 93, 26, 98, 79, 12, 16, 27,
        78, 95,  8, 11, 80, 20,  4, 57, 73, 54,  2,  7, 66,  3,  5, 88, 37, 59,
        17, 48, 50, 41, 86, 89, 76, 63, 31])
torch.Size([99])
tensor([30, 36, 49, 35, 55, 65,  0, 28, 61, 52, 45, 69, 10, 21, 83, 18, 56,  9,
        14, 70, 39, 19, 74, 43, 68, 75, 60, 81, 29, 47, 94, 24, 58, 77, 64, 15,
        13, 72, 87, 32, 71, 51, 85,  6, 44, 34, 96, 40, 38, 97, 46, 53, 82, 84,
        22, 90, 25, 23, 33, 92,  1, 62, 42, 91, 67, 93, 26, 98, 79, 12, 16, 27,
        78, 95,  8, 11, 80, 20,  4, 57, 73, 54,  2,  7, 66,  3,  5, 88, 37, 59,
        17, 48, 50, 41, 86, 89, 76, 63, 31], dtype=torch.int32)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0])
multiple_times(log10) [12.06995569]
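
As a rough illustration of the round trip shown above, the sketch below works on a tiny hand-made chain of four nodes instead of TensorNetworkEnv. The `EDGES` table, the helper names `edge_sort_to_node2s` and `node2s_to_edge_sort`, and the rule that the merged node keeps the id of the first node in the pair are assumptions made for this example; they are not the repository's implementation.

```python
# Hypothetical 4-node chain (tensor-train-like): edge id -> (node, node).
EDGES = {0: (0, 1), 1: (1, 2), 2: (2, 3)}


def edge_sort_to_node2s(edge_sort):
    """Edge contraction order -> list of (node, node) pairs."""
    alias = {n: n for pair in EDGES.values() for n in pair}

    def find(n):  # follow aliases to the id the node currently goes by
        while alias[n] != n:
            n = alias[n]
        return n

    node2s = []
    for e in edge_sort:
        a, b = (find(n) for n in EDGES[e])
        node2s.append((a, b))
        alias[b] = a  # assumption: the merged node keeps the id of `a`
    return node2s


def node2s_to_edge_sort(node2s):
    """List of (node, node) pairs -> edge contraction order."""
    edges_of = {}
    for e, (u, v) in EDGES.items():
        edges_of.setdefault(u, set()).add(e)
        edges_of.setdefault(v, set()).add(e)

    edge_sort = []
    for a, b in node2s:
        shared = edges_of[a] & edges_of[b]
        edge_sort.append(min(shared))
        # the merged node (still called `a`) keeps every edge that is left
        edges_of[a] = (edges_of[a] | edges_of.pop(b)) - shared
    return edge_sort


edge_sort = [1, 0, 2]
node2s = edge_sort_to_node2s(edge_sort)
print(node2s)
print(node2s_to_edge_sort(node2s) == edge_sort)
```

On a chain every pair of merged nodes shares exactly one edge, so the conversion back is unambiguous; graphs with cycles need a rule for choosing among several shared edges.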
@Yonv1943 Yonv1943 added the discussion code understanding label Apr 9, 2023
@Yonv1943
Collaborator Author

Yonv1943 commented Apr 9, 2023

65e2e9a

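The commit above adds an update. After verification, the node contraction order used by Zhang Pan turns out to differ from ours.

Taking nodes 47 and 70, the two endpoints of the edge to be contracted, we looked them up in the undirected node-connection graph and found that they are not connected. We therefore suspected that our node numbering differs from theirs. (Yet our numbering is the scheme officially provided for Sycamore, not one of our own.)

Later checking showed the cause: they give the contracted node a new index, whereas we reuse the index of either node from before the contraction, so the connectivity lookup fails. (2023-04-14 16:03:48)

Please note:

  • Our simulation environment supports the case where the contracted node takes the index of one of its constituent nodes. For example, after nodes A, B, and C are contracted, the environment can handle the result whether it is indexed A, B, or C.
  • Our simulation environment has already been verified with the nodes_id ↔ edges_id round-trip conversion above.
  • They may have removed some qubits from Sycamore.

# node2s = Node2sSycamoreN53N20Zhang1Pan1V1  # ERROR because of their non-standard node_id (from paper)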

num_nodes            381
num_edges            754
ban_edges              0
47 tensor([107, 110,  66,  65], dtype=torch.int32) ;;; 70 tensor([246, 152, 112, 111], dtype=torch.int32)
47 tensor([68, 69, 26, 27], dtype=torch.int32) ;;; 70 tensor([137,  90,  28,  48], dtype=torch.int32)

This leads to the following IndexError:

  line 592, in convert_node2s_to_edge_sort
    edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range
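
The failure can be reproduced outside the environment with the two rows printed in the log above. This is a minimal sketch written with plain Python sets rather than np.intersect1d; `shared_edge` is a hypothetical helper that raises a descriptive error instead of indexing an empty list.

```python
# Rows copied from the log above: edge ids attached to node 47 and node 70.
edges_47 = [107, 110, 66, 65]
edges_70 = [246, 152, 112, 111]


def shared_edge(edges_a, edges_b):
    """Return the smallest real edge id shared by two nodes (-1 is padding)."""
    common = sorted((set(edges_a) & set(edges_b)) - {-1})
    if not common:
        raise ValueError("nodes share no edge, so they cannot be contracted")
    return common[0]


try:
    shared_edge(edges_47, edges_70)
except ValueError as err:
    print(err)
```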

The contraction order in their paper differs from the one in their code. Trying the order from their code, shown below, produces the same error:

# node2s = Node2sSycamoreN53N20Zhang1Pan1V2  # ERROR because of their non-standard node_id (from code)

num_nodes            381
num_edges            754
ban_edges              0
47 tensor([107, 110,  66,  65], dtype=torch.int32) ;;; 7 tensor([26, 32,  2,  7], dtype=torch.int32)
47 tensor([68, 69, 26, 27], dtype=torch.int32) ;;; 7 tensor([26, 30,  2, 15], dtype=torch.int32)

Below is a third contraction order, which still fails:

32 tensor([120,  76,  36,  -1], dtype=torch.int32) ;;; 74 tensor([376, 160, 120, 119], dtype=torch.int32)
32 tensor([74, 52,  9, -1], dtype=torch.int32) ;;; 74 tensor([202,  94,  32,  52], dtype=torch.int32)
52 tensor([119, 122,  76,  75], dtype=torch.int32) ;;; 32 tensor([120,  76,  36,  -1], dtype=torch.int32)
52 tensor([74, 75, 32, 33], dtype=torch.int32) ;;; 32 tensor([74, 52,  9, -1], dtype=torch.int32)
75 tensor([159, 162, 122, 121], dtype=torch.int32) ;;; 94 tensor([208, 199, 160, 159], dtype=torch.int32)
75 tensor([94, 95, 52, 53], dtype=torch.int32) ;;; 94 tensor([118, 114,  74,  75], dtype=torch.int32)
52 tensor([119, 122,  76,  75], dtype=torch.int32) ;;; 75 tensor([159, 162, 122, 121], dtype=torch.int32)
52 tensor([74, 75, 32, 33], dtype=torch.int32) ;;; 75 tensor([94, 95, 52, 53], dtype=torch.int32)
13 tensor([38, 43,  5, 50], dtype=torch.int32) ;;; 33 tensor([75, 78, 38, 37], dtype=torch.int32)
13 tensor([33, 36,  5, 39], dtype=torch.int32) ;;; 33 tensor([52, 53, 13, 10], dtype=torch.int32)
53 tensor([121, 124,  78,  77], dtype=torch.int32) ;;; 13 tensor([38, 43,  5, 50], dtype=torch.int32)
53 tensor([75, 76, 33, 34], dtype=torch.int32) ;;; 13 tensor([33, 36,  5, 39], dtype=torch.int32)

Traceback (most recent call last):
line 669, in convert_node2s_to_edge_sort
    edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range

@spicywei
Contributor

spicywei commented Apr 9, 2023

I offer another version to try out

@spicywei
Contributor

spicywei commented Apr 9, 2023

The error may come from the open qubits used in their method, which reduce the number of tensors from 381 to 345 and could cause a numbering inconsistency between us and them.

@Yonv1943
Collaborator Author

Yonv1943 commented Apr 13, 2023

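Here is how to verify a node contraction order.

The function below takes the contraction order given by the node contraction list node2s: list, together with the edge index table generated automatically by our simulation environment, and returns the contraction order expressed as edge indices.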

    def convert_node2s_to_edge_sort(self, node2s: list) -> TEN:
        edges_ary: TEN = self.edges_ary.cpu()
        nodes_ary: TEN = self.nodes_ary.cpu()

        edge_sort = []
        import numpy as np
        for node_i0, node_i1 in node2s:
            print(f"{node_i0:4} {str(edges_ary[node_i0].numpy()):17}    "
                  f"{node_i1:4} {str(edges_ary[node_i1].numpy()):17}   |"
                  f"{node_i0:4} {str(nodes_ary[node_i0].numpy()):17}    "
                  f"{node_i1:4} {str(nodes_ary[node_i1].numpy()):17}")

            edge_is = np.intersect1d(edges_ary[node_i0], edges_ary[node_i1])
            edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
            edge_sort.append(edge_i)
        edge_sort = th.tensor(edge_sort)
        return edge_sort

With the following input:

nodes_list, ban_edges = NodesSycamoreN53M12, 0
Node2sSycamoreN53N20Xu3Wei3 = [
    (32, 9), (190, 189), (3, 0), ...]

node2s = Node2sSycamoreN53N20Xu3Wei3
edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)

we get an error, along with this log output:

num_nodes            381
num_edges            754
ban_edges              0
  32 [120  76  36  -1]       9 [29 36 44 -1]       |  32 [74 52  9 -1]           9 [28 32 36 -1]    
 190 [436 396 352 351]     189 [350 349 568 520]   | 190 [232 212 171 170]     189 [170 147 298 274]
...
    edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range

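Reading the first tensor contraction:

32 [120 76 36 -1] 9 [29 36 44 -1] | 32 [74 52 9 -1] 9 [28 32 36 -1]

  • The first entry of node2s is (32, 9): contract nodes 32 and 9.
  • The program looks them up in edges_ary, the table of edges attached to each node, and finds that the two nodes are joined by edge 36:
    • node 32 is attached to edges [120 76 36 -1]
    • node 9 is attached to edges [29 36 44 -1]
  • The program then looks them up in nodes_ary, the table of each node's neighbouring nodes, and confirms that the two nodes are connected, because each records the other's index:
    • node 32 is connected to nodes [74 52 9 -1], which include node 9
    • node 9 is connected to nodes [28 32 36 -1], which include node 32

Reading the next tensor contraction, we hit the error.

This is not because the user-supplied node2s is inconsistent, but because I did not update the connection tables after the contraction.

190 [436 396 352 351] 189 [350 349 568 520] | 190 [232 212 171 170] 189 [170 147 298 274]

  • The second entry of node2s is (190, 189): contract nodes 190 and 189.
  • The program looks them up in edges_ary and finds that the intersection of the edges attached to the two nodes is empty, so they are not connected and cannot be contracted:
    • node 190 is attached to edges [436 396 352 351]
    • node 189 is attached to edges [350 349 568 520]
  • The program then looks them up in nodes_ary and confirms that the two nodes are not connected, because neither records the other's index:
    • node 190 is connected to nodes [232 212 171 170], which do not include node 189
    • node 189 is connected to nodes [170 147 298 274], which do not include node 190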
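
The neighbour check walked through above can be sketched with the four nodes_ary rows copied from the log. `NODES_ARY_ROWS` and `are_neighbours` are names invented for this example, not identifiers from TNCO_env.py.

```python
# Rows copied from the log: node id -> ids of its neighbouring nodes (-1 is padding).
NODES_ARY_ROWS = {
    32: [74, 52, 9, -1],
    9: [28, 32, 36, -1],
    190: [232, 212, 171, 170],
    189: [170, 147, 298, 274],
}


def are_neighbours(node_a, node_b, rows=NODES_ARY_ROWS):
    """True only if each node lists the other as a neighbour."""
    return node_b in rows[node_a] and node_a in rows[node_b]


print(are_neighbours(32, 9))     # True
print(are_neighbours(190, 189))  # False
```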

@Yonv1943
Collaborator Author

Yonv1943 commented Apr 14, 2023

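The problem is solved; the bug was in my conversion code.

  • The first error arose because the other party's contraction order gives the contracted tensor a new node_id. Changing it to take over either of the node_ids from before the contraction fixes that.
  • The later error arose because, once tensors node_i0 and node_i1 are contracted, it is unclear which of the two node_ids the result should take. My code did not handle this, hence the error.
  • The final fix: the contracted node simply takes over either of the node_ids from before the contraction.
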
Node2sSycamoreN53N20Zhang1Pan1V4 = [
    (32, 9), (190, 189), (3, 0), (6, 1), (13, 5), (169, 127), (127, 146), (192, 152), (193, 153), (196, 113),
    (197, 175), (201, 74), (205, 160), (208, 164), (210, 167), (7, 2), (12, 4), (16, 10), (17, 11), (18, 14), (23, 8),
    (35, 15), (42, 20), (20, 62), (43, 22), (45, 24), (24, 66), (46, 25), (47, 26), (48, 28), (28, 70), (49, 29),
    (50, 30), (51, 27), (52, 33), (53, 34), (56, 36), (36, 78), (57, 37), (58, 38), (59, 39), (39, 81), (60, 40),
    (61, 41), (41, 83), (41, 103), (63, 21), (64, 44), (76, 54), (77, 55), (55, 97), (86, 65), (87, 67), (88, 68),
    (89, 69), (90, 71), (91, 72), (92, 73), (94, 75), (98, 79), (99, 80), (101, 82), (104, 19), (105, 84), (107, 85),
    (115, 95), (116, 96), (117, 93), (121, 100), (124, 102), (130, 108), (131, 111), (132, 106), (106, 149), (133, 114),
    (136, 109), (137, 118), (138, 110), (139, 120), (140, 112), (141, 122), (142, 123), (144, 119), (145, 126),
    (126, 168), (148, 128), (128, 170), (150, 129), (129, 171), (151, 135), (157, 134), (158, 143), (172, 147),
    (173, 154), (174, 155), (178, 159), (180, 161), (181, 162), (182, 163), (183, 156), (184, 165), (185, 166),
    (194, 176), (195, 177), (199, 198), (203, 202), (206, 186), (207, 187), (209, 188), (109, 152), (118, 160),
    (118, 179), (110, 153), (122, 164), (167, 125), (21, 20), (44, 22), (54, 34), (67, 24), (71, 28), (79, 36),
    (82, 39), (135, 108), (155, 106), (166, 143), (30, 31), (147, 191), (159, 200), (202, 204), (159, 113), (110, 161),
    (122, 156), (125, 119), (188, 127), (20, 84), (22, 65), (24, 25), (28, 29), (36, 37), (39, 40), (106, 112),
    (143, 102), (147, 189), (118, 175), (113, 93), (122, 114), (127, 165), (202, 74), (93, 109), (22, 8), (24, 68),
    (28, 72), (36, 80), (106, 163), (147, 128), (34, 11), (30, 2), (127, 123), (74, 9), (20, 0), (119, 41), (93, 69),
    (8, 85), (24, 26), (36, 38), (106, 100), (128, 19), (102, 96), (123, 134), (9, 33), (41, 126), (114, 118),
    (19, 108), (69, 27), (2, 4), (123, 95), (9, 75), (8, 1), (19, 129), (27, 73), (2, 15), (100, 55), (41, 10),
    (19, 154), (9, 5), (95, 14), (19, 111), (55, 177), (19, 120), (5, 162), (14, 176), (5, 198), (5, 187), (5, 186),
    (5, 114), (5, 110), (5, 28), (5, 10), (5, 14), (5, 36), (5, 39), (5, 11), (5, 96), (5, 19), (5, 55), (5, 27),
    (5, 2), (2, 24), (2, 1), (1, 0)
]

For the input above, the output is multiple_times(log10) [17.49347767], which completes the verification.

num_nodes            211
num_edges            414
ban_edges              0
torch.Size([210])
tensor([ 36, 354,   0,   1,   5, 309, 310, 358, 361, 369, 372, 385, 397, 406,
        413,   2,   4,   8,   9,  10,  19,  42,  56,  96,  57,  62, 104,  64,
         66,  68, 112,  70,  72,  73,  75,  77,  84, 128,  86,  88,  90, 133,
         92,  94, 137, 177,  97,  99, 123, 125, 165, 143, 145, 147, 149, 151,
        153, 155, 159, 167, 169, 173, 180, 182, 186, 201, 203, 206, 213, 219,
        231, 233, 236, 269, 237, 244, 245, 248, 249, 252, 253, 255, 260, 261,
        307, 268, 312, 272, 314, 273, 286, 287, 317, 318, 320, 328, 332, 334,
        336, 339, 340, 342, 364, 367, 379, 391, 399, 403, 409, 275, 291, 330,
        277, 299, 305,  98, 100,  80, 106, 152, 168, 136, 274, 321, 304,  71,
        355, 381, 395, 370, 333, 338, 306, 348, 139, 144, 105, 113, 129,  89,
        281, 257, 353, 373, 289, 284, 349, 388, 329,  59, 148, 154, 170, 337,
        350,  39,  32, 264, 120,  13, 221, 189, 141, 107, 131, 297, 266, 258,
        285,  76, 411, 322, 232, 109,  35, 256, 160,  58, 357, 117,  41, 214,
         54, 356,  38, 225, 280, 366, 250, 389, 324, 378, 347, 404, 384, 345,
        200,  37, 162,  43,  49, 161, 224, 288, 352, 195,  33, 192,  34,   3])
Not Standard edge_sorts
multiple_times(log10) [17.49347767]
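
The relabelling described in this comment can be sketched as follows. It assumes, purely for illustration, that the other convention names the tensor produced at step k as num_nodes + k; the thread does not state their exact scheme, and `relabel_new_ids` is a hypothetical helper rather than code from TNCO_env.py.

```python
def relabel_new_ids(node2s_new, num_nodes):
    """Rewrite pairs that use fresh ids for merged nodes so that each merged
    node instead reuses the id of the first node of its pair."""
    alias = {}  # fresh id -> id that is reused in its place
    node2s = []
    for step, (a, b) in enumerate(node2s_new):
        a, b = alias.get(a, a), alias.get(b, b)
        node2s.append((a, b))
        alias[num_nodes + step] = a  # assumption: step k creates id num_nodes + k
    return node2s


# 4 original nodes (0..3); the merged nodes would have been called 4, 5, 6.
print(relabel_new_ids([(0, 1), (2, 3), (4, 5)], num_nodes=4))
```

Because every alias is stored already resolved, a single dictionary lookup per node is enough, however many contractions came before.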

@Yonv1943
Collaborator Author

Yonv1943 commented Apr 14, 2023

After checking, we found a second error:

The code is at ...\Python39\Lib\site-packages\opt_einsum\contract.py

    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        # Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
        cost_list.append(cost)
        print(cost)  # -----------------------------------------------------------------> print the number of multiplications produced by each tensor contraction
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shps.pop(x) for x in contract_inds]

As the code above shows, I read the opt_einsum source, printed the multiplication count of every contraction, and collected the numbers in an Excel sheet.

aa72fd7a4b963f0fa4972e4d857c24d

Blue is their (correct) result and red is ours; every contraction differs by a factor of 2**1.

d63ab9e6498ef7acf4b088031bccdc5

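During initialization we give each node one extra 2 ** 1 mark for itself. Later, when removing the double-counted multiplications, we subtract the nodes flagged by bool_ary and end up subtracting the node itself once too often.

So a one-line change fixes it: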

            '''calculate the multiple and avoid repeat'''
            ct_dimss = dims_tens[env_is, node_i0s] + dims_tens[env_is, node_i1s] * if_diffs.unsqueeze(1)
            ct_bools = bool_tens[env_is, node_i0s] | bool_tens[env_is, node_i1s]
            # assert ct_dimss.shape == (num_envs, num_nodes)
            # assert ct_bools.shape == (num_envs, num_nodes)

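            # During initialization each node gets one extra 2 ** 1 mark for itself; removing the double-counted multiplications then subtracts the node itself once too often, so the line below adds it back
            pow_times = ct_dimss.sum(dim=1) - (ct_dimss * ct_bools).sum(dim=1) * 0.5 + 1  # --------> the added + 1 fixes the bug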
            pow_timess[:, i] = pow_times * if_diffs

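The results are:

node2s = Node2sSycamoreN53N20Test1  
# log10(multiple_times) = 
their result  25.6106868931126
our result    25.6106813   # power_max - 960
our result    25.61068416  # power_max - 512

Multiplication count computed with Python's built-in arbitrary-precision int:
exact result  40802511241875888470868352
exact result  25.6106868931126

Our result is currently accurate to 7 significant digits. I can keep improving it, but that is unrelated to this issue.

The precision is tied to the code below: to handle larger values in float64 it subtracts (power_max - 960), at the cost of some precision.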

        # Even float64 occasionally overflows when computing this multiplication count, so divide by 2**temp_power first and restore it after taking log10
        adj_pow_times = pow_timess.max(dim=1)[0] - 960  # automatically set `max - 960`, 960 < the limit 1024,
        multiple_times = (2 ** (pow_timess - adj_pow_times.unsqueeze(1))).sum(dim=1)
        multiple_times = multiple_times.log10() + adj_pow_times / th.log2(th.tensor((10,), device=device))
        # adj_pow_times / th.log2(th.tensor((10, ), device=device))  # Change of Base Formula
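
For comparison, here is one common way to take the logarithm of a sum of powers of two in floating point: shift every exponent by the maximum, so that each term lies in (0, 1], and add the shift back afterwards. This is a sketch in plain Python, not the repository's code; `log10_sum_of_pow2` is a hypothetical name, and whether this shift recovers the digits lost above has not been checked against TensorNetworkEnv.

```python
import math


def log10_sum_of_pow2(powers):
    """log10(sum(2 ** p for p in powers)) without forming the huge sum."""
    p_max = max(powers)
    # Every term is scaled into (0, 1], so nothing can overflow;
    # terms far below the maximum simply underflow to 0.0.
    scaled_sum = math.fsum(2.0 ** (p - p_max) for p in powers)
    # Change of base: log10(x) = log2(x) * log10(2)
    return (p_max + math.log2(scaled_sum)) * math.log10(2.0)


powers = [2000.0, 1999.0, 5.0]
print(log10_sum_of_pow2(powers))
# Cross-check against Python's arbitrary-precision integers.
print(math.log10(2 ** 2000 + 2 ** 1999 + 2 ** 5))
```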

@Yonv1943
Collaborator Author

Here is a slow method that loses no precision at all.
The result is:

multiple_times(log10) [25.61068689]
diff 0.000e+00

The code is:

        # slow, but loses no precision at all
        multiple_times = []
        pow_timess = pow_timess.cpu().numpy()
        for env_id in range(num_envs):
            multiple_time = 0
            for pow_time in pow_timess[env_id, :]:
                multiple_time = multiple_time + 2 ** pow_time
            multiple_time = math_log10(multiple_time)
            multiple_times.append(multiple_time)
        multiple_times = th.tensor(multiple_times, dtype=th.float64)

@Yonv1943
Collaborator Author

Yonv1943 commented Apr 24, 2023

Important addendum:
I believe it is opt_einsum's calculation that is wrong, not ours. The earlier comparison between our TensorNetworkEnv and opt_einsum is the "second error" comment above: there I printed the cost of every contraction from opt_einsum's contract.py and put the numbers side by side in an Excel sheet, with their results in blue and ours in red, and every contraction differed by a factor of 2**1.

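The +1 discussed below really amounts to adding log10(2**1) to the final multiplication count.

  1. @spicywei worked out the TensorTrain result by hand (6 nodes, not counting virtual nodes) and found that the +1 is not needed.
  2. We looked at the multiplication counts of the first few contractions of the Sycamore circuit and found that the +1 is not needed; it is opt_einsum that is wrong.

Verification steps:

  1. In the Node2sSycamoreN53N20 circuit, following the contraction order Node2sSycamoreN53N20Test2 = [(32, 9), (360, 359), ...], we first contract nodes 32 and 9.
  2. Node 32 is connected to nodes [74 52 9 -1], which include node 9; node 9 is connected to nodes [28 32 36 -1], which include node 32.
  3. The multiplication count is computed as follows: node 32 has 3 edges, node 9 has 3 edges, and one edge between them is contracted, so the count is 2 ** (3 + 3 - 1) == 2 ** 5.
  4. As the table screenshot below shows, opt_einsum gets 2 ** 6 while we get 2 ** 5. Ours is correct, so there is no need to add log10(2**1) to the multiplication count.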
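
The arithmetic in step 3 can be checked with the two edges_ary rows printed earlier in this thread. The sketch assumes every edge has dimension 2, as the 2 ** (...) formula above implies; `log2_multiplications` is a hypothetical helper.

```python
# Rows copied from the log: edge ids attached to each node (-1 is padding).
edges_32 = [120, 76, 36, -1]
edges_9 = [29, 36, 44, -1]


def log2_multiplications(edges_a, edges_b):
    """log2 of the multiplication count for contracting two tensors whose
    edges all have dimension 2: each distinct edge involved counts once."""
    a = set(edges_a) - {-1}
    b = set(edges_b) - {-1}
    return len(a) + len(b) - len(a & b)


power = log2_multiplications(edges_32, edges_9)
print(power, 2 ** power)  # 5 32
```

One possible source of the factor of two, stated here only as an unverified assumption: a cost model that counts one addition alongside each multiplication would report 2 ** 6 for this step.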

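As the screenshot below shows, the multiplication count for contracting nodes 32 and 9 is recorded in the first row; the left column is opt_einsum's result and the right column is our TensorNetworkEnv's.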

image

@spicywei
Contributor

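Let me add how the problem was found.
For opt_einsum's random greedy run:
69243d44c94b3364de6ec5e966b010d
Each step reports the scaling of the corresponding contraction (which is correct). Following the order it gives, the result should be log10(224) = 2.350248018, but it reports 2.65127, which is larger by log10(2). Here is a rough sketch:
059ac383c543d4aeac131d225ac2c74
This shows that we should not add this error to our code just because opt_einsum overcounts by log10(2).
To be sure, here is our result again, which equals log10(224):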
image
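
The two numbers quoted in this comment can be checked with a few lines of arithmetic. This is only a sanity check of the logarithms; it does not rerun opt_einsum.

```python
import math

ours = math.log10(224)                    # the value worked out by hand
with_extra_factor = ours + math.log10(2)  # the same as log10(448)

print(round(ours, 9))
print(round(with_extra_factor, 5))
```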

@ZhangAIPI
Contributor

I re-verified the process above locally, and the opt_einsum calculation error does exist.
