// File: Utilities\ValueSetFactory.NumericValueSet.cs
// Web Access
// Project: ..\..\..\src\Compilers\CSharp\Portable\Microsoft.CodeAnalysis.CSharp.csproj (Microsoft.CodeAnalysis.CSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;
 
namespace Microsoft.CodeAnalysis.CSharp
{
    using static BinaryOperatorKind;
 
    internal static partial class ValueSetFactory
    {
        /// <summary>
        /// A value set for a numeric type <typeparamref name="T"/>, represented as a list of closed
        /// intervals <c>[first..last]</c> kept in increasing order. In debug builds the constructor checks
        /// that every interval is non-empty and that consecutive intervals neither overlap nor touch.
        /// All comparisons and successor/predecessor arithmetic are delegated to <typeparamref name="TTC"/>.
        /// </summary>
        private sealed class NumericValueSet<T, TTC> : IValueSet<T> where TTC : struct, INumericTC<T>
        {
            /// <summary>The member intervals, in increasing order, with a gap between each pair.</summary>
            private readonly ImmutableArray<(T first, T last)> _intervals;

            /// <summary>The set of every value from <c>MinValue</c> through <c>MaxValue</c>.</summary>
            public static readonly NumericValueSet<T, TTC> AllValues = new NumericValueSet<T, TTC>(default(TTC).MinValue, default(TTC).MaxValue);

            /// <summary>The set with no members.</summary>
            public static readonly NumericValueSet<T, TTC> NoValues = new NumericValueSet<T, TTC>(ImmutableArray<(T first, T last)>.Empty);

            /// <summary>
            /// Creates the set containing exactly the values from <paramref name="first"/> through <paramref name="last"/>.
            /// </summary>
            internal NumericValueSet(T first, T last) : this(ImmutableArray.Create((first, last)))
            {
                Debug.Assert(default(TTC).Related(LessThanOrEqual, first, last));
            }

            /// <summary>
            /// Wraps an already-normalized interval list. In debug builds, verifies that the first interval starts
            /// at or above <c>MinValue</c>, that each interval is non-empty, and that each interval begins strictly
            /// after the successor of the previous interval's end (so there is always a gap between them).
            /// </summary>
            internal NumericValueSet(ImmutableArray<(T first, T last)> intervals)
            {
#if DEBUG
                TTC tc = default;
                Debug.Assert(intervals.Length == 0 || tc.Related(GreaterThanOrEqual, intervals[0].first, tc.MinValue));
                for (int index = 0; index < intervals.Length; index++)
                {
                    var (lo, hi) = intervals[index];
                    Debug.Assert(tc.Related(LessThanOrEqual, lo, hi));

                    // Increasing order with at least one non-member value between neighbours.
                    Debug.Assert(index == 0 || tc.Related(LessThan, tc.Next(intervals[index - 1].last), lo));
                }
#endif
                _intervals = intervals;
            }

            /// <summary>True when the set has no members.</summary>
            public bool IsEmpty => _intervals.Length == 0;

            /// <summary>The smallest member; only meaningful when the set is non-empty.</summary>
            private T Lowest => _intervals[0].first;

            /// <summary>The largest member; only meaningful when the set is non-empty.</summary>
            private T Highest => _intervals[_intervals.Length - 1].last;

            ConstantValue IValueSet.Sample
            {
                get
                {
                    if (IsEmpty)
                    {
                        throw new ArgumentException();
                    }

                    // Prefer the smallest member that is >= zero. When there is none, fall back to the
                    // largest member of the set, which is then the member nearest zero.
                    TTC tc = default;
                    var atLeastZero = NumericValueSetFactory<T, TTC>.Instance.Related(GreaterThanOrEqual, tc.Zero);
                    var candidates = (NumericValueSet<T, TTC>)Intersect(atLeastZero);
                    T chosen = candidates.IsEmpty ? Highest : candidates.Lowest;
                    return tc.ToConstantValue(chosen);
                }
            }

            /// <summary>
            /// Returns true when at least one member <c>x</c> satisfies <c>x relation value</c>.
            /// Supported relations are <c>LessThan</c>, <c>LessThanOrEqual</c>, <c>GreaterThan</c>,
            /// <c>GreaterThanOrEqual</c> and <c>Equal</c>; any other relation throws, even for the empty set.
            /// </summary>
            public bool Any(BinaryOperatorKind relation, T value)
            {
                TTC tc = default;
                switch (relation)
                {
                    case Equal:
                        return ContainsValue(value);

                    case LessThan:
                    case LessThanOrEqual:
                        // Some member is below value exactly when the smallest member is.
                        return !IsEmpty && tc.Related(relation, Lowest, value);

                    case GreaterThan:
                    case GreaterThanOrEqual:
                        // Some member is above value exactly when the largest member is.
                        return !IsEmpty && tc.Related(relation, Highest, value);

                    default:
                        throw ExceptionUtilities.UnexpectedValue(relation);
                }
            }

            /// <summary>
            /// Binary search for an interval containing <paramref name="value"/>: narrows to the first interval
            /// whose upper bound is not below <paramref name="value"/>, then checks that interval's bounds.
            /// </summary>
            private bool ContainsValue(T value)
            {
                TTC tc = default;
                int low = 0;
                int high = _intervals.Length - 1;
                if (high < low)
                {
                    return false;
                }

                while (low < high)
                {
                    int mid = low + (high - low) / 2;
                    if (tc.Related(LessThanOrEqual, value, _intervals[mid].last))
                    {
                        high = mid;
                    }
                    else
                    {
                        low = mid + 1;
                    }
                }

                var (first, last) = _intervals[low];
                return tc.Related(GreaterThanOrEqual, value, first) && tc.Related(LessThanOrEqual, value, last);
            }

            // A bad constant yields true without being converted to T.
            bool IValueSet.Any(BinaryOperatorKind relation, ConstantValue value)
            {
                if (value.IsBad)
                {
                    return true;
                }

                return Any(relation, default(TTC).FromConstantValue(value));
            }

            /// <summary>
            /// Returns true when every member <c>x</c> satisfies <c>x relation value</c>. This is vacuously
            /// true for the empty set, whatever the relation. Otherwise the supported relations match <see cref="Any"/>.
            /// </summary>
            public bool All(BinaryOperatorKind relation, T value)
            {
                if (IsEmpty)
                {
                    return true;
                }

                TTC tc = default;
                switch (relation)
                {
                    case LessThan:
                    case LessThanOrEqual:
                        // Every member is below value exactly when the largest member is.
                        return tc.Related(relation, Highest, value);

                    case GreaterThan:
                    case GreaterThanOrEqual:
                        // Every member is above value exactly when the smallest member is.
                        return tc.Related(relation, Lowest, value);

                    case Equal:
                        // Only a single interval whose both ends equal value qualifies.
                        return _intervals.Length == 1 && tc.Related(Equal, Lowest, value) && tc.Related(Equal, Highest, value);

                    default:
                        throw ExceptionUtilities.UnexpectedValue(relation);
                }
            }

            // A bad constant yields false without being converted to T.
            bool IValueSet.All(BinaryOperatorKind relation, ConstantValue value)
            {
                if (value.IsBad)
                {
                    return false;
                }

                return All(relation, default(TTC).FromConstantValue(value));
            }

            /// <summary>
            /// Returns the set of all values of <typeparamref name="T"/> that are not members of this set.
            /// </summary>
            public IValueSet<T> Complement()
            {
                if (IsEmpty)
                {
                    return AllValues;
                }

                TTC tc = default;
                var gaps = ArrayBuilder<(T first, T last)>.GetInstance();

                // Values below the first interval, when there are any.
                T lowest = Lowest;
                if (tc.Related(LessThan, tc.MinValue, lowest))
                {
                    gaps.Add((tc.MinValue, tc.Prev(lowest)));
                }

                // The gap between each pair of neighbouring intervals.
                for (int i = 1; i < _intervals.Length; i++)
                {
                    gaps.Add((tc.Next(_intervals[i - 1].last), tc.Prev(_intervals[i].first)));
                }

                // Values above the last interval, when there are any.
                T highest = Highest;
                if (tc.Related(LessThan, highest, tc.MaxValue))
                {
                    gaps.Add((tc.Next(highest), tc.MaxValue));
                }

                return new NumericValueSet<T, TTC>(gaps.ToImmutableAndFree());
            }

            IValueSet IValueSet.Complement() => Complement();

            /// <summary>
            /// Returns the values present in both this set and <paramref name="o"/>, computed with a single
            /// merge-style pass over the two sorted interval lists.
            /// </summary>
            public IValueSet<T> Intersect(IValueSet<T> o)
            {
                ImmutableArray<(T first, T last)> mine = _intervals;
                ImmutableArray<(T first, T last)> theirs = ((NumericValueSet<T, TTC>)o)._intervals;
                TTC tc = default;
                var result = ArrayBuilder<(T first, T last)>.GetInstance();
                int i = 0;
                int j = 0;
                while (i < mine.Length && j < theirs.Length)
                {
                    var (aFirst, aLast) = mine[i];
                    var (bFirst, bLast) = theirs[j];
                    if (tc.Related(LessThan, aLast, bFirst))
                    {
                        // a lies entirely below b.
                        i++;
                    }
                    else if (tc.Related(LessThan, bLast, aFirst))
                    {
                        // b lies entirely below a.
                        j++;
                    }
                    else
                    {
                        // The intervals overlap. Keep the shared part, then advance whichever interval
                        // ends first, or both when they end together.
                        Add(result, Max(aFirst, bFirst), Min(aLast, bLast));
                        bool aEndsFirst = tc.Related(LessThan, aLast, bLast);
                        bool bEndsFirst = !aEndsFirst && tc.Related(LessThan, bLast, aLast);
                        if (!bEndsFirst)
                        {
                            i++;
                        }

                        if (!aEndsFirst)
                        {
                            j++;
                        }
                    }
                }

                return new NumericValueSet<T, TTC>(result.ToImmutableAndFree());
            }

            /// <summary>
            /// Appends the interval [<paramref name="first"/>..<paramref name="last"/>] to <paramref name="builder"/>.
            /// The interval is instead merged into the builder's final interval when it overlaps or is adjacent to it.
            /// </summary>
            private static void Add(ArrayBuilder<(T first, T last)> builder, T first, T last)
            {
                TTC tc = default;
                Debug.Assert(tc.Related(LessThanOrEqual, first, last));
                Debug.Assert(tc.Related(GreaterThanOrEqual, first, tc.MinValue));
                Debug.Assert(tc.Related(LessThanOrEqual, last, tc.MaxValue));
                Debug.Assert(builder.Count == 0 || tc.Related(LessThanOrEqual, builder.Last().first, first));

                // When first == MinValue, Prev(first) would be out of range. In that case the previous
                // interval necessarily overlaps, so the MinValue test short-circuits the Prev call.
                bool touchesPrevious = builder.Count > 0 &&
                    (tc.Related(Equal, tc.MinValue, first) || tc.Related(GreaterThanOrEqual, builder.Last().last, tc.Prev(first)));

                if (!touchesPrevious)
                {
                    builder.Add((first, last));
                    return;
                }

                // Extend the previous interval instead of appending an overlapping or adjacent one.
                var previous = builder.Pop();
                builder.Push((previous.first, Max(last, previous.last)));
            }

            /// <summary>Returns <paramref name="a"/> if it is strictly less than <paramref name="b"/>, otherwise <paramref name="b"/>.</summary>
            private static T Min(T a, T b) => default(TTC).Related(LessThan, a, b) ? a : b;

            /// <summary>Returns <paramref name="b"/> if <paramref name="a"/> is strictly less than it, otherwise <paramref name="a"/>.</summary>
            private static T Max(T a, T b) => default(TTC).Related(LessThan, a, b) ? b : a;

            IValueSet IValueSet.Intersect(IValueSet other) => Intersect((IValueSet<T>)other);

            /// <summary>
            /// Returns the values present in this set, in <paramref name="o"/>, or in both. It merges the two
            /// sorted interval lists in order and relies on <see cref="Add"/> to coalesce overlapping or
            /// adjacent intervals.
            /// </summary>
            public IValueSet<T> Union(IValueSet<T> o)
            {
                ImmutableArray<(T first, T last)> mine = _intervals;
                ImmutableArray<(T first, T last)> theirs = ((NumericValueSet<T, TTC>)o)._intervals;
                TTC tc = default;
                var result = ArrayBuilder<(T first, T last)>.GetInstance();
                int i = 0;
                int j = 0;
                while (i < mine.Length && j < theirs.Length)
                {
                    var (aFirst, aLast) = mine[i];
                    var (bFirst, bLast) = theirs[j];
                    if (tc.Related(LessThan, aLast, bFirst))
                    {
                        Add(result, aFirst, aLast);
                        i++;
                    }
                    else if (tc.Related(LessThan, bLast, aFirst))
                    {
                        Add(result, bFirst, bLast);
                        j++;
                    }
                    else
                    {
                        // Overlapping: emit the hull of the two and consume both. Any later interval from
                        // either side that overlaps this hull is merged into it by Add.
                        Add(result, Min(aFirst, bFirst), Max(aLast, bLast));
                        i++;
                        j++;
                    }
                }

                // At most one side still has intervals; append them in order.
                for (; i < mine.Length; i++)
                {
                    Add(result, mine[i].first, mine[i].last);
                }

                for (; j < theirs.Length; j++)
                {
                    Add(result, theirs[j].first, theirs[j].last);
                }

                return new NumericValueSet<T, TTC>(result.ToImmutableAndFree());
            }

            IValueSet IValueSet.Union(IValueSet other) => Union((IValueSet<T>)other);

            /// <summary>
            /// Produce a random value set for testing purposes. It draws <paramref name="expectedSize"/> * 2
            /// random values, sorts them, and pairs consecutive values into intervals. Overlapping or adjacent
            /// pairs are merged.
            /// </summary>
            internal static IValueSet<T> Random(int expectedSize, Random random)
            {
                TTC tc = default;
                int endpointCount = expectedSize * 2;
                var endpoints = new T[endpointCount];
                for (int k = 0; k < endpointCount; k++)
                {
                    endpoints[k] = tc.Random(random);
                }

                Array.Sort(endpoints);

                var builder = ArrayBuilder<(T first, T last)>.GetInstance();
                for (int k = 0; k < endpoints.Length; k += 2)
                {
                    Add(builder, endpoints[k], endpoints[k + 1]);
                }

                return new NumericValueSet<T, TTC>(builder.ToImmutableAndFree());
            }

            /// <summary>
            /// A string representation for testing purposes. Each interval is rendered as
            /// <c>[first..last]</c> using <typeparamref name="TTC"/>'s formatting, and intervals are joined by commas.
            /// </summary>
            public override string ToString()
            {
                TTC tc = default;
                var parts = new string[_intervals.Length];
                for (int k = 0; k < parts.Length; k++)
                {
                    var (first, last) = _intervals[k];
                    parts[k] = $"[{tc.ToString(first)}..{tc.ToString(last)}]";
                }

                return string.Join(",", parts);
            }

            /// <summary>Two sets are equal when their interval lists are element-wise equal.</summary>
            public override bool Equals(object? obj)
            {
                var other = obj as NumericValueSet<T, TTC>;
                return other != null && this._intervals.SequenceEqual(other._intervals);
            }

            /// <summary>Hashes the intervals together with their count, consistent with <see cref="Equals(object)"/>.</summary>
            public override int GetHashCode() => Hash.Combine(Hash.CombineValues(_intervals), _intervals.Length);
        }
    }
}