class Solution:
    """Count pairs of nodes that cannot reach each other in an undirected graph."""

    def countPairs(self, n: int, edges: List[List[int]]) -> int:
        """Return the number of unordered node pairs with no path between them.

        Args:
            n: Number of nodes; nodes are labelled ``0`` through ``n - 1``.
            edges: Undirected edges, each given as an ``[a, b]`` pair of
                node labels.

        Returns:
            The number of pairs of distinct nodes that lie in different
            connected components.
        """
        adjacency = [[] for _ in range(n)]
        for a, b in edges:
            adjacency[a].append(b)
            adjacency[b].append(a)

        visited = [False] * n
        remaining = n  # nodes not yet assigned to a processed component
        pairs = 0
        for start in range(n):
            if visited[start]:
                continue

            # Explicit stack instead of recursion: a long chain of nodes
            # would otherwise exceed the interpreter's recursion limit.
            visited[start] = True
            stack = [start]
            size = 0
            while stack:
                node = stack.pop()
                size += 1
                for neighbour in adjacency[node]:
                    if not visited[neighbour]:
                        visited[neighbour] = True
                        stack.append(neighbour)

            # Every node in this component is unreachable from every node
            # in the components still to be processed; pairing only with
            # later components counts each unordered pair exactly once.
            remaining -= size
            pairs += size * remaining
        return pairs