File: InternalUtilities\OneOrMany.cs
Web Access
Project: ..\..\..\src\Compilers\Core\Portable\Microsoft.CodeAnalysis.csproj (Microsoft.CodeAnalysis)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.PooledObjects;
 
namespace Roslyn.Utilities
{
    /// <summary>
    /// Represents a single item or many items (including none).
    /// </summary>
    /// <remarks>
    /// Used when a collection usually contains a single item but sometimes might contain multiple.
    /// </remarks>
    internal readonly struct OneOrMany<T>
    {
        public static readonly OneOrMany<T> Empty = new OneOrMany<T>(ImmutableArray<T>.Empty);
 
        private readonly T? _one;
        private readonly ImmutableArray<T> _many;
 
        public OneOrMany(T one)
        {
            _one = one;
            _many = default;
        }
 
        public OneOrMany(ImmutableArray<T> many)
        {
            if (many.IsDefault)
            {
                throw new ArgumentNullException(nameof(many));
            }
 
            _one = default;
            _many = many;
        }
 
        [MemberNotNullWhen(true, nameof(_one))]
        private bool HasOne
            => _many.IsDefault;
 
        public T this[int index]
        {
            get
            {
                if (HasOne)
                {
                    if (index != 0)
                    {
                        throw new IndexOutOfRangeException();
                    }
 
                    return _one;
                }
                else
                {
                    return _many[index];
                }
            }
        }
 
        public int Count
            => HasOne ? 1 : _many.Length;
 
        public bool IsEmpty
            => Count == 0;
 
        public OneOrMany<T> Add(T one)
        {
            var builder = ArrayBuilder<T>.GetInstance(this.Count + 1);
            if (HasOne)
            {
                builder.Add(_one);
            }
            else
            {
                builder.AddRange(_many);
            }
 
            builder.Add(one);
            return new OneOrMany<T>(builder.ToImmutableAndFree());
        }
 
        public bool Contains(T item)
        {
            if (HasOne)
                return EqualityComparer<T>.Default.Equals(item, _one);
 
            foreach (var value in _many)
            {
                if (EqualityComparer<T>.Default.Equals(item, value))
                    return true;
            }
 
            return false;
        }
 
        public OneOrMany<T> RemoveAll(T item)
        {
            if (HasOne)
            {
                return EqualityComparer<T>.Default.Equals(item, _one) ? default : this;
            }
 
            var builder = ArrayBuilder<T>.GetInstance();
 
            foreach (var value in _many)
            {
                if (!EqualityComparer<T>.Default.Equals(item, value))
                    builder.Add(value);
            }
 
            if (builder.Count == 0)
            {
                builder.Free();
                return default;
            }
 
            return builder.Count == Count ? this : new OneOrMany<T>(builder.ToImmutableAndFree());
        }
 
        public OneOrMany<TResult> Select<TResult>(Func<T, TResult> selector)
        {
            return HasOne ?
                OneOrMany.Create(selector(_one)) :
                OneOrMany.Create(_many.SelectAsArray(selector));
        }
 
        public OneOrMany<TResult> Select<TResult, TArg>(Func<T, TArg, TResult> selector, TArg arg)
        {
            return HasOne ?
                OneOrMany.Create(selector(_one, arg)) :
                OneOrMany.Create(_many.SelectAsArray(selector, arg));
        }
 
        public T? FirstOrDefault(Func<T, bool> predicate)
        {
            if (HasOne)
            {
                return predicate(_one) ? _one : default;
            }
 
            foreach (var item in _many)
            {
                if (predicate(item))
                {
                    return item;
                }
            }
 
            return default;
        }
 
        public T? FirstOrDefault<TArg>(Func<T, TArg, bool> predicate, TArg arg)
        {
            if (HasOne)
            {
                return predicate(_one, arg) ? _one : default;
            }
 
            foreach (var item in _many)
            {
                if (predicate(item, arg))
                {
                    return item;
                }
            }
 
            return default;
        }
 
        public Enumerator GetEnumerator()
            => new(this);
 
        internal struct Enumerator
        {
            private readonly OneOrMany<T> _collection;
            private int _index;
 
            internal Enumerator(OneOrMany<T> collection)
            {
                _collection = collection;
                _index = -1;
            }
 
            public bool MoveNext()
            {
                _index++;
                return _index < _collection.Count;
            }
 
            public T Current => _collection[_index];
        }
    }
 
    internal static class OneOrMany
    {
        public static OneOrMany<T> Create<T>(T one)
            => new OneOrMany<T>(one);
 
        public static OneOrMany<T> Create<T>(ImmutableArray<T> many)
            => new OneOrMany<T>(many);
    }
}