
Custom Aggregations In LINQ

Saturday, March 29, 2008

Aggregate is a standard LINQ operator for in-memory collections that allows us to build a custom aggregation. LINQ provides a few standard aggregation operators, like Count, Min, Max, and Average, but if you want an inline implementation of, say, a standard deviation calculation, then the Aggregate extension method is one approach you can use (the other approach is to write your own operator).
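For comparison, the other approach (writing your own operator) might look something like the minimal sketch below. StatisticsExtensions and StdDev are hypothetical names, not part of LINQ, and the sketch assumes we want the sample standard deviation (dividing by n - 1).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical custom operator; StdDev is not part of LINQ.
public static class StatisticsExtensions
{
    // Sample standard deviation of the values chosen by selector.
    public static double StdDev<TSource>(this IEnumerable<TSource> source,
                                         Func<TSource, double> selector)
    {
        if (source == null) throw new ArgumentNullException(nameof(source));
        if (selector == null) throw new ArgumentNullException(nameof(selector));

        // Materialize once so the source is enumerated a single time.
        double[] values = source.Select(selector).ToArray();
        if (values.Length < 2)
        {
            return 0.0;
        }

        double mean = values.Average();
        double sumOfSquares = values.Sum(v => (v - mean) * (v - mean));
        return Math.Sqrt(sumOfSquares / (values.Length - 1));
    }
}
```

With that in place, a call such as processes.StdDev(p => p.Threads.Count) reads like any built-in operator, at the cost of maintaining a separate method.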

Let's say we wanted to see the total number of threads running on a machine. We could get that number lambda style, or with a query comprehension, or with a custom aggregate.

using System.Diagnostics;
using System.Linq;

var processes = Process.GetProcesses();

int totalThreads = 0;

// lambda style
totalThreads = processes.Sum(p => p.Threads.Count);

// query comprehension
totalThreads = (from process in processes
                select process.Threads.Count).Sum();

// custom aggregate
totalThreads =
    processes.Aggregate(
        0,                                 // initialize
        (acc, p) => acc + p.Threads.Count, // accumulate
        acc => acc                         // terminate
    );

This particular overload of Aggregate follows a common "Initialize, Accumulate, Terminate" pattern. You can see this pattern in extensible aggregation frameworks ranging from Oracle to SQLCLR. The first parameter is the seed that initializes the accumulator: in this case just the integer value 0.
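To make the three steps concrete, here is a rough sketch of what the seed/func/resultSelector overload does. MyAggregate is a hypothetical name, and this is not the framework's actual source (the real method also validates its arguments); it only mirrors the documented behavior.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class AggregateSketch
{
    // Hypothetical re-implementation of the three-argument Aggregate overload.
    public static TResult MyAggregate<TSource, TAccumulate, TResult>(
        this IEnumerable<TSource> source,
        TAccumulate seed,                             // initialize
        Func<TAccumulate, TSource, TAccumulate> func, // accumulate
        Func<TAccumulate, TResult> resultSelector)    // terminate
    {
        TAccumulate acc = seed;
        foreach (TSource item in source)
        {
            // Each call receives the previous accumulator and returns the next one.
            acc = func(acc, item);
        }
        return resultSelector(acc);
    }
}
```

Because the accumulate function returns the new accumulator, returning acc + p.Threads.Count is enough; there is no need to mutate the parameter.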

The second parameter is a Func<int, Process, int> delegate that Aggregate invokes once for each element as it iterates across the sequence. For each process it passes in the current accumulator value (an int) and a reference to the current process (a Process), and we return a new accumulator value (an int).

The last parameter is the terminate function (Aggregate calls it the result selector). This is an opportunity to perform any final calculations. For our summation, we simply return the value in the accumulator.
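The terminate step earns its keep when the accumulator holds more than the final answer. As an illustrative sketch (the names AverageWithAggregate and Mean are hypothetical, and value tuples assume C# 7 or later), an average can carry a running sum and count, leaving the division to the terminate function.

```csharp
using System;
using System.Linq;

public static class AverageWithAggregate
{
    // Average via Aggregate: the accumulator is a (Sum, Count) tuple and the
    // terminate step divides. Unlike Enumerable.Average, this returns 0.0 for an
    // empty array instead of throwing (a design choice for this sketch).
    public static double Mean(int[] values)
    {
        return values.Aggregate(
            (Sum: 0L, Count: 0),                                        // initialize
            (acc, v) => (acc.Sum + v, acc.Count + 1),                    // accumulate
            acc => acc.Count == 0 ? 0.0 : (double)acc.Sum / acc.Count); // terminate
    }
}
```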

StdDev

Now let's compute a more thorough summary of running threads, including a standard deviation. Although we could get away with a simple double accumulator for the standard deviation, a more sophisticated accumulator can encapsulate some of the calculations, facilitate unit tests, and make the syntax easier on the eye.
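For reference, the "simple double accumulator" version might look like this sketch (SimpleStdDev and Compute are hypothetical names). The mean is computed up front, Aggregate sums the squared deviations into a double, and the terminate step divides by n - 1 and takes the square root.

```csharp
using System;
using System.Linq;

public static class SimpleStdDev
{
    // Sample standard deviation with a plain double accumulator.
    public static double Compute(int[] values)
    {
        if (values.Length < 2)
        {
            return 0.0;
        }

        double mean = values.Average();
        int n = values.Length;

        return values.Aggregate(
            0.0,                                       // initialize
            (acc, v) => acc + (v - mean) * (v - mean), // accumulate squared deviations
            acc => Math.Sqrt(acc / (n - 1)));          // terminate
    }
}
```

It works, but the mean and count live in captured locals rather than in one testable object, which is the motivation for the accumulator class.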

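using System;
using System.Collections.Generic;
using System.Linq;

class StdDevAccumulator<TSource>
{
    public StdDevAccumulator(IEnumerable<TSource> source,
                             Func<TSource, double> avgSelector)
    {
        SampleCount = source.Count();
        // Average throws on an empty sequence, so guard against it.
        SampleAvg = SampleCount > 0 ? source.Average(avgSelector) : 0.0;
    }

    // Adds the squared deviation of one value from the sample average.
    public StdDevAccumulator<TSource> Accumulate(double value)
    {
        TotalDeviation += Math.Pow(value - SampleAvg, 2.0);
        return this;
    }

    // Sample standard deviation: divide by (n - 1), then take the square root.
    public double ComputeResult()
    {
        if (SampleCount < 2)
        {
            return 0.0;
        }
        return Math.Sqrt(TotalDeviation / (SampleCount - 1));
    }

    public double SampleAvg { get; set; }
    public int SampleCount { get; set; }
    public double TotalDeviation { get; set; }
}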
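Since one stated goal of the accumulator is to facilitate unit tests, here is a sketch of how it might be exercised against a known data set. The block repeats the StdDevAccumulator class (with a guard so an empty source does not throw in the constructor) so that it compiles on its own; StdDevAccumulatorChecks and its methods are hypothetical helpers, and no particular test framework is assumed.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Copy of StdDevAccumulator, with an empty-source guard, so this block stands alone.
public class StdDevAccumulator<TSource>
{
    public StdDevAccumulator(IEnumerable<TSource> source,
                             Func<TSource, double> avgSelector)
    {
        SampleCount = source.Count();
        SampleAvg = SampleCount > 0 ? source.Average(avgSelector) : 0.0;
    }

    public StdDevAccumulator<TSource> Accumulate(double value)
    {
        TotalDeviation += Math.Pow(value - SampleAvg, 2.0);
        return this;
    }

    public double ComputeResult()
    {
        if (SampleCount < 2)
        {
            return 0.0;
        }
        return Math.Sqrt(TotalDeviation / (SampleCount - 1));
    }

    public double SampleAvg { get; set; }
    public int SampleCount { get; set; }
    public double TotalDeviation { get; set; }
}

// Hypothetical test helpers.
public static class StdDevAccumulatorChecks
{
    // Drives the accumulator by hand so each intermediate property can be inspected.
    public static StdDevAccumulator<int> RunByHand(int[] data)
    {
        var acc = new StdDevAccumulator<int>(data, x => x);
        foreach (int x in data)
        {
            acc.Accumulate(x);
        }
        return acc;
    }

    // Drives the same accumulator through Aggregate, as in the summary query.
    public static double RunWithAggregate(int[] data)
    {
        return data.Aggregate(
            new StdDevAccumulator<int>(data, x => x),
            (acc, x) => acc.Accumulate(x),
            acc => acc.ComputeResult());
    }
}
```

The data set {2, 4, 4, 4, 5, 5, 7, 9} has a mean of 5 and squared deviations summing to 32, so the expected sample standard deviation is sqrt(32 / 7), roughly 2.138.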

Put the accumulator to use like so:

var processes = Process.GetProcesses();

var summary = new
    {
        TotalProcesses = processes.Count(),
        TotalThreads = processes.Sum(p => p.Threads.Count),
        MinThreads = processes.Min(p => p.Threads.Count),
        MaxThreads = processes.Max(p => p.Threads.Count),
        StdDevThreads = processes.Aggregate(
                new StdDevAccumulator<Process>(processes, p => p.Threads.Count),
                (acc, p) => acc.Accumulate(p.Threads.Count),
                acc => acc.ComputeResult()
        )
    };