Define a uniform forest to be a nonempty forest in which every tree is uniform and the roots of all the trees have the same value.
Let $dp(i,j)$ be the number of uniform forests in the subtree rooted at node $i$ whose roots all have value $j$.
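First, ignore node $i$ itself. Each child $x$ of $i$ contributes either nothing or one of its $dp(x,j)$ forests, and we subtract 1 to exclude the choice where every child contributes nothing:
$$dp(i,j) = \left(\prod_{x \in child(i)} \bigl(dp(x,j)+1\bigr)\right) - 1.$$
If we do include node $i$, we can place it on top of any uniform forest or use it on its own, so we add $1$ plus the sum of all the current dp values to $dp(i,v_i)$. The $+1$ is needed: without it every leaf, and therefore every node, would have all dp values equal to 0.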
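This step also gives us the final answer. The amount added, $1 + \sum_j dp(i,j)$ (taken before the addition), is the number of uniform subsets whose topmost node is $i$, so summing it over all nodes $i$ counts every uniform subset exactly once.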
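To make the recurrence concrete, here is a minimal sketch of the naive DP. It assumes a hypothetical input format (node 0 is the root, `parent[i] < i` for every `i > 0`, and every value lies in `[0, V)`) and a modulus of 10^9+7; neither is stated above. The name `countUniformNaive` is also hypothetical. It stores $dp(i,\cdot)$ as a dense array, so it uses $O(nV)$ time and memory.

```cpp
#include <cassert>
#include <vector>

// Assumed modulus; the problem's actual modulus is not given in this section.
const long long MOD = 1000000007LL;

// Naive DP over dense arrays. Hypothetical input format (assumption):
// node 0 is the root, parent[i] < i for i > 0, and 0 <= val[i] < V.
long long countUniformNaive(const std::vector<int>& parent,
                           const std::vector<int>& val, int V) {
    const int n = static_cast<int>(parent.size());
    // prod[i][j] accumulates the product over children x of (dp(x, j) + 1).
    std::vector<std::vector<long long>> prod(n, std::vector<long long>(V, 1));
    long long answer = 0;
    // Children have larger indices, so each is finished before its parent.
    for (int i = n - 1; i >= 0; --i) {
        std::vector<long long>& dp = prod[i];
        long long total = 0;
        for (int j = 0; j < V; ++j) {
            dp[j] = (dp[j] - 1 + MOD) % MOD;  // forests that ignore node i
            total = (total + dp[j]) % MOD;
        }
        // Trees rooted at i: i on its own (the +1) or i on top of a forest.
        const long long treesAtI = (total + 1) % MOD;
        answer = (answer + treesAtI) % MOD;  // i is the topmost node
        dp[val[i]] = (dp[val[i]] + treesAtI) % MOD;
        if (i > 0) {
            assert(parent[i] < i);  // required by the processing order
            for (int j = 0; j < V; ++j)
                prod[parent[i]][j] = prod[parent[i]][j] * (dp[j] + 1) % MOD;
        }
    }
    return answer;
}
```

Under these assumptions, a root with two children that all share one value gives 6: every nonempty subset except the two leaves taken without the root.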
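Doing this naively takes $O(n^2)$ time. Let $s_i$ denote the number of nodes in the subtree rooted at node $i$. For a fixed $i$, the number of nonzero values $dp(i,j)$ is at most $s_i$, since $j$ must be the value of some node in that subtree. Thus, we keep only these entries in a map, together with their sum so that adding node $i$ takes a single map update. Combining the children is then done with a small-to-large DFS: when combining two children, only iterate over the smaller map, because a key $j$ missing from it has factor $0+1=1$ and keeps its value in the larger map. Each merge costs at most the smaller of the two subtree sizes, so the total is $O(n \log n)$ map operations, which is $O(n \log^2 n)$ time with an ordered map such as std::map (or expected $O(n \log n)$ with a hash map).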
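A minimal sketch of the small-to-large version follows. It assumes a hypothetical input format (node 0 is the root and `parent[i] < i` for every `i > 0`) and a modulus of 10^9+7, neither of which is stated above; the names `Dp`, `mergeInto`, and `countUniform` are hypothetical. Each `Dp` keeps only the present keys in a `std::map` plus a running sum of its values.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

const long long MOD = 1000000007LL;  // assumed modulus

// DP for one node (or a group of merged children): value j -> dp(., j),
// storing only keys that appear, plus the sum of all stored values.
struct Dp {
    std::map<int, long long> cnt;
    long long sum = 0;
};

// Combine two groups: dp <- (dpA + 1) * (dpB + 1) - 1, key by key.
// Keys absent from the smaller map have factor 0 + 1 = 1 and keep their
// value, so only the smaller map is walked (small-to-large).
void mergeInto(Dp& a, Dp& b) {
    if (a.cnt.size() < b.cnt.size()) std::swap(a, b);  // cheap for std::map
    for (const auto& kv : b.cnt) {
        long long& av = a.cnt[kv.first];  // inserts 0 if the key is missing
        const long long nv = ((av + 1) * (kv.second + 1) - 1) % MOD;
        a.sum = ((a.sum - av + nv) % MOD + MOD) % MOD;  // keep the sum current
        av = nv;
    }
    b = Dp{};  // the child's map is no longer needed
}

long long countUniform(const std::vector<int>& parent, const std::vector<int>& val) {
    const int n = static_cast<int>(parent.size());
    std::vector<Dp> dp(n);  // empty map = no children yet (all zeros)
    long long answer = 0;
    for (int i = n - 1; i >= 0; --i) {
        // dp[i] now holds the product formula over i's children only.
        const long long treesAtI = (dp[i].sum + 1) % MOD;
        answer = (answer + treesAtI) % MOD;  // i is the topmost node
        long long& own = dp[i].cnt[val[i]];
        own = (own + treesAtI) % MOD;
        dp[i].sum = (dp[i].sum + treesAtI) % MOD;
        if (i > 0) {
            assert(parent[i] < i);  // children must be processed first
            mergeInto(dp[parent[i]], dp[i]);
        }
    }
    return answer;
}
```

An empty `Dp` is the identity for `mergeInto`, since $(0+1)(B+1)-1 = B$, which is why each parent can start with an empty map and absorb its children one at a time.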
C++ implementation with comments.