less than 1 minute read

2421. Number of Good Paths (hard)

time exceeded at 123/134 testcases

class Solution:

    def numberOfGoodPaths(self, vals: List[int], edges: List[List[int]]) -> int:
        tree = defaultdict(list)
        for s, e in edges:
            tree[s].append(e)
            tree[e].append(s)
        
        same_value_nodes = defaultdict(list)
        for i, value in enumerate(vals):
            same_value_nodes[value].append(i)
        
        # print(same_value_nodes)

        def dfs(start, end, par, target):
            nonlocal tree, vals
            if start == end:
                return True
            if vals[start] > target:
                return False
            res = False
            for node in tree[start]:
                if node != par:
                    res |= dfs(node, end, start, target)
            return res

        count = 0
        for nodes in same_value_nodes.values():
            n = len(nodes)
            for i in range(n):
                for j in range(i, n):
                    s, e = nodes[i], nodes[j]
                    if dfs(s, e, -1, vals[s]):
                        count += 1
        return count