Code:
/ Dotnetfx_Vista_SP2 / Dotnetfx_Vista_SP2 / 8.0.50727.4016 / DEVDIV / depot / DevDiv / releases / Orcas / QFE / ndp / fx / src / DataEntity / System / Data / Objects / ELinq / MethodCallTranslator.cs / 1 / MethodCallTranslator.cs
//---------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// @owner [....], [....]
//---------------------------------------------------------------------

using System.Data.Common.CommandTrees;
using System.Collections.Generic;
using CqtExpression = System.Data.Common.CommandTrees.DbExpression;
using LinqExpression = System.Linq.Expressions.Expression;
using System.Diagnostics;
using System.Data.Metadata.Edm;
using System.Linq.Expressions;
using System.Reflection;
using System.Linq;
using System.Data.Entity;
using System.Data.Common;
using System.Globalization;

namespace System.Data.Objects.ELinq
{
    internal sealed partial class ExpressionConverter
    {
        /// <summary>
        /// Translates System.Linq.Expression.MethodCallExpression to System.Data.Common.CommandTrees.DbExpression
        /// </summary>
        private sealed class MethodCallTranslator : TypedTranslator<MethodCallExpression>
        {
            internal MethodCallTranslator()
                : base(ExpressionType.Call)
            {
            }

            /// <summary>
            /// Dispatches a method call to the first matching translator, trying in order:
            /// sequence-method translators, per-MethodInfo translators, ObjectQuery&lt;&gt;
            /// builder-method translators (matched by method name), and finally the
            /// default translator.
            /// </summary>
            /// <param name="parent">Converter that owns the translation.</param>
            /// <param name="linq">Method call expression to translate.</param>
            /// <returns>The result of the selected translator.</returns>
            protected override CqtExpression TypedTranslate(ExpressionConverter parent, MethodCallExpression linq)
            {
                // check if this is a known sequence method
                SequenceMethod sequenceMethod;
                SequenceMethodTranslator sequenceTranslator;
                if (ReflectionUtil.TryIdentifySequenceMethod(linq.Method, out sequenceMethod) &&
                    s_sequenceTranslators.TryGetValue(sequenceMethod, out sequenceTranslator))
                {
                    return sequenceTranslator.Translate(parent, linq, sequenceMethod);
                }

                // check if this is a known method
                CallTranslator callTranslator;
                if (TryGetCallTranslator(linq.Method, out callTranslator))
                {
                    return callTranslator.Translate(parent, linq);
                }

                // check if this is an ObjectQuery<> builder method
                Type declaringType = linq.Method.DeclaringType;
                if (linq.Method.IsPublic && null != declaringType && declaringType.IsGenericType &&
                    typeof(ObjectQuery<>) == declaringType.GetGenericTypeDefinition())
                {
                    ObjectQueryCallTranslator builderTranslator;
                    if (s_objectQueryTranslators.TryGetValue(linq.Method.Name, out builderTranslator))
                    {
                        return builderTranslator.Translate(parent, linq);
                    }
                }

                // fall back on the default translator
                return s_defaultTranslator.Translate(parent, linq);
            }

            #region Static members and initializers
            private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

            // initialize fall-back translator
            private static readonly CallTranslator s_defaultTranslator = new DefaultTranslator();

            // Read without a lock by TryGetCallTranslator. The instance is never mutated once
            // it has been assigned here: InitializeVBMethods builds an extended copy and swaps
            // the reference, so the field is volatile rather than readonly.
            private static volatile Dictionary<MethodInfo, CallTranslator> s_methodTranslators = InitializeMethodTranslators();
            private static readonly Dictionary<SequenceMethod, SequenceMethodTranslator> s_sequenceTranslators = InitializeSequenceMethodTranslators();
            private static readonly Dictionary<string, ObjectQueryCallTranslator> s_objectQueryTranslators = InitializeObjectQueryTranslators();
            private static bool s_vbMethodsInitialized;
            private static readonly object s_vbInitializerLock = new object();

            private static Dictionary<MethodInfo, CallTranslator> InitializeMethodTranslators()
            {
                // initialize translators for specific methods (e.g., Int32.op_Equality)
                Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>();
                foreach (CallTranslator translator in GetCallTranslators())
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        methodTranslators.Add(method, translator);
                    }
                }
                return methodTranslators;
            }

            private static Dictionary<SequenceMethod, SequenceMethodTranslator> InitializeSequenceMethodTranslators()
            {
                // initialize translators for sequence methods (e.g., Sequence.Select)
                Dictionary<SequenceMethod, SequenceMethodTranslator> sequenceTranslators = new Dictionary<SequenceMethod, SequenceMethodTranslator>();
                foreach (SequenceMethodTranslator translator in GetSequenceMethodTranslators())
                {
                    foreach (SequenceMethod method in translator.Methods)
                    {
                        sequenceTranslators.Add(method, translator);
                    }
                }
                return sequenceTranslators;
            }

            private static Dictionary<string, ObjectQueryCallTranslator> InitializeObjectQueryTranslators()
            {
                // initialize translators for object query methods (e.g. ObjectQuery<T>.OfType<S>(), ObjectQuery.Include(string) )
                Dictionary<string, ObjectQueryCallTranslator> objectQueryCallTranslators = new Dictionary<string, ObjectQueryCallTranslator>(StringComparer.Ordinal);
                foreach (ObjectQueryCallTranslator translator in GetObjectQueryCallTranslators())
                {
                    objectQueryCallTranslators[translator.MethodName] = translator;
                }
                return objectQueryCallTranslators;
            }

            /// <summary>
            /// Tries to get a translator for the given method info.
            /// If the given method info corresponds to a Visual Basic property,
            /// it also initializes the Visual Basic translators if they have not been initialized
            /// </summary>
            /// <param name="methodInfo">Method being called.</param>
            /// <param name="callTranslator">Receives the registered translator, or null when there is none.</param>
            /// <returns>true when a translator is registered for <paramref name="methodInfo"/>; otherwise false.</returns>
            private static bool TryGetCallTranslator(MethodInfo methodInfo, out CallTranslator callTranslator)
            {
                if (s_methodTranslators.TryGetValue(methodInfo, out callTranslator))
                {
                    return true;
                }

                // check if this is the visual basic assembly
                // (DeclaringType may be null, as TypedTranslate already allows for; such a
                // method cannot belong to the Visual Basic assembly)
                Type declaringType = methodInfo.DeclaringType;
                if (null != declaringType && s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
                {
                    lock (s_vbInitializerLock)
                    {
                        if (!s_vbMethodsInitialized)
                        {
                            InitializeVBMethods(declaringType.Assembly);
                            s_vbMethodsInitialized = true;
                        }
                        // try again
                        return s_methodTranslators.TryGetValue(methodInfo, out callTranslator);
                    }
                }

                callTranslator = null;
                return false;
            }

            /// <summary>
            /// Registers the Visual Basic call translators. Called with s_vbInitializerLock held.
            /// </summary>
            /// <param name="vbAssembly">Assembly the Visual Basic runtime types are resolved from.</param>
            private static void InitializeVBMethods(Assembly vbAssembly)
            {
                Debug.Assert(!s_vbMethodsInitialized);

                // Dictionary does not support a reader running concurrently with Add, and
                // TryGetCallTranslator reads the table without taking the lock. Extend a
                // private copy and publish it with a single reference assignment.
                Dictionary<MethodInfo, CallTranslator> extended = new Dictionary<MethodInfo, CallTranslator>(s_methodTranslators);
                foreach (CallTranslator translator in GetVisualBasicCallTranslators(vbAssembly))
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        extended.Add(method, translator);
                    }
                }
                s_methodTranslators = extended;
            }

            private static IEnumerable<CallTranslator> GetVisualBasicCallTranslators(Assembly vbAssembly)
            {
                yield return new VBCanonicalFunctionDefaultTranslator(vbAssembly);
                yield return new VBCanonicalFunctionRenameTranslator(vbAssembly);
                yield return new VBDatePartTranslator(vbAssembly);
            }

            private static IEnumerable<CallTranslator> GetCallTranslators()
            {
                yield return new CanonicalFunctionDefaultTranslator();
                yield return new ContainsTranslator();
                yield return new StartsWithTranslator();
                yield return new EndsWithTranslator();
                yield return new IndexOfTranslator();
                yield return new SubstringTranslator();
                yield return new RemoveTranslator();
                yield return new InsertTranslator();
                yield return new IsNullOrEmptyTranslator();
                yield return new StringConcatTranslator();
                yield return new TrimStartTranslator();
                yield return new TrimEndTranslator();
            }

            private static IEnumerable<SequenceMethodTranslator> GetSequenceMethodTranslators()
            {
                yield return new ConcatTranslator();
                yield return new UnionTranslator();
                yield return new IntersectTranslator();
                yield return new ExceptTranslator();
                yield return new DistinctTranslator();
                yield return new WhereTranslator();
                yield return new SelectTranslator();
                yield return new OrderByTranslator();
                yield return new OrderByDescendingTranslator();
                yield return new ThenByTranslator();
                yield return new ThenByDescendingTranslator();
                yield return new SelectManyTranslator();
                yield return new AnyTranslator();
                yield return new AnyPredicateTranslator();
                yield return new AllTranslator();
                yield return new JoinTranslator();
                yield return new GroupByTranslator();
                yield return new MaxTranslator();
                yield return new MinTranslator();
                yield return new AverageTranslator();
                yield return new SumTranslator();
                yield return new CountTranslator();
                yield return new LongCountTranslator();
                yield return new CastMethodTranslator();
                yield return new GroupJoinTranslator();
                yield return new OfTypeTranslator();
                yield return new SingleTranslatorNotSupported();
                yield return new PassthroughTranslator();
                yield return new FirstTranslator();
                yield return new FirstPredicateTranslator();
                yield return new FirstOrDefaultTranslator();
                yield return new FirstOrDefaultPredicateTranslator();
                yield return new TakeTranslator();
                yield return new SkipTranslator();
            }

            private static IEnumerable<ObjectQueryCallTranslator> GetObjectQueryCallTranslators()
            {
                yield return new ObjectQueryBuilderDistinctTranslator();
                yield return new ObjectQueryBuilderExceptTranslator();
                yield return new ObjectQueryBuilderFirstTranslator();
                yield return new ObjectQueryIncludeTranslator();
                yield return new ObjectQueryBuilderIntersectTranslator();
                yield return new ObjectQueryBuilderOfTypeTranslator();
                yield return new ObjectQueryBuilderUnionTranslator();
            }
            #endregion

            #region Method translators
            /// <summary>
            /// Base class for translators keyed by the MethodInfo instances they handle.
            /// </summary>
            private abstract class CallTranslator
            {
                private readonly IEnumerable<MethodInfo> _methods;

                protected CallTranslator(params MethodInfo[] methods)
                {
                    _methods = methods;
                }

                protected CallTranslator(IEnumerable<MethodInfo> methods)
                {
                    _methods = methods;
                }

                // The methods this translator is registered for.
                internal IEnumerable<MethodInfo> Methods
                {
                    get { return _methods; }
                }

                internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

                public override string ToString()
                {
                    return GetType().Name;
                }
            }

            /// <summary>
            /// Base class for translators of ObjectQuery methods, which are registered by
            /// method name rather than by MethodInfo.
            /// </summary>
            private abstract class ObjectQueryCallTranslator : CallTranslator
            {
                private readonly string _methodName;

                protected ObjectQueryCallTranslator(string methodName)
                {
                    _methodName = methodName;
                }

                internal string MethodName
                {
                    get { return _methodName; }
                }
            }

            /// <summary>
            /// Translates an ObjectQuery builder method by delegating to the translator
            /// registered for the equivalent sequence method.
            /// </summary>
            private abstract class ObjectQueryBuilderCallTranslator : ObjectQueryCallTranslator
            {
                private readonly SequenceMethodTranslator _translator;

                protected ObjectQueryBuilderCallTranslator(string methodName, SequenceMethod sequenceEquivalent)
                    : base(methodName)
                {
                    bool translatorFound = s_sequenceTranslators.TryGetValue(sequenceEquivalent, out _translator);
                    Debug.Assert(translatorFound, "Translator not found for " + sequenceEquivalent.ToString());
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    return _translator.Translate(parent, call);
                }
            }

            private sealed class ObjectQueryBuilderUnionTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderUnionTranslator()
                    : base("Union", SequenceMethod.Union)
                {
                }
            }

            private sealed class ObjectQueryBuilderIntersectTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderIntersectTranslator()
                    : base("Intersect", SequenceMethod.Intersect)
                {
                }
            }

            private sealed class ObjectQueryBuilderExceptTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderExceptTranslator()
                    : base("Except", SequenceMethod.Except)
                {
                }
            }

private sealed class ObjectQueryBuilderDistinctTranslator : ObjectQueryBuilderCallTranslator { internal ObjectQueryBuilderDistinctTranslator() : base("Distinct", SequenceMethod.Distinct) { } } private sealed class ObjectQueryBuilderOfTypeTranslator : ObjectQueryBuilderCallTranslator { internal ObjectQueryBuilderOfTypeTranslator() : base("OfType", SequenceMethod.OfType) { } } private sealed class ObjectQueryBuilderFirstTranslator : ObjectQueryBuilderCallTranslator { internal ObjectQueryBuilderFirstTranslator() : base("First", SequenceMethod.First) { } } private sealed class ObjectQueryIncludeTranslator : ObjectQueryCallTranslator { internal ObjectQueryIncludeTranslator() : base("Include") { } internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(string)), "Invalid Include arguments?"); CqtExpression queryExpression = parent.TranslateExpression(call.Object); Span span; if (!parent.TryGetSpan(queryExpression, out span)) { span = null; } CqtExpression arg = parent.TranslateExpression(call.Arguments[0]); string includePath = null; if (arg.ExpressionKind == DbExpressionKind.Constant) { includePath = (string)((DbConstantExpression)arg).Value; } else { // The 'Include' method implementation on ELinqQueryState creates // a method call expression with a string constant argument taking // the value of the string argument passed to ObjectQuery.Include, // and so this is the only supported pattern here. 
throw EntityUtil.NotSupported(Entity.Strings.ELinq_UnsupportedInclude); } return parent.AddSpanMapping(queryExpression, Span.IncludeIn(span, includePath)); } } private sealed class DefaultTranslator : CallTranslator { internal DefaultTranslator() : base() { } internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { MethodInfo suggestedMethodInfo; if (TryGetAlternativeMethod(call.Method, out suggestedMethodInfo)) { throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethodSuggestedAlternative(call.Method, suggestedMethodInfo)); } //The default error message throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethod(call.Method)); } #region Static Members private static readonly Dictionary s_alternativeMethods = InitializeAlternateMethodInfos(); private static bool s_vbMethodsInitialized; private static readonly object s_vbInitializerLock = new object(); /// /// Tries to check whether there is an alternative method suggested insted of the given unsupported one. /// /// /// ///private static bool TryGetAlternativeMethod(MethodInfo originalMethodInfo, out MethodInfo suggestedMethodInfo) { if (s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo)) { return true; } // check if this is the visual basic assembly if (s_visualBasicAssemblyFullName == originalMethodInfo.DeclaringType.Assembly.FullName) { lock (s_vbInitializerLock) { if (!s_vbMethodsInitialized) { InitializeVBMethods(originalMethodInfo.DeclaringType.Assembly); s_vbMethodsInitialized = true; } // try again return s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo); } } suggestedMethodInfo = null; return false; } /// /// Initializes the dictionary of alternative methods. /// Currently, it simply initializes an empty dictionary. 
/// ///private static Dictionary InitializeAlternateMethodInfos() { return new Dictionary (1); } /// /// Populates the dictionary of alternative methods with the VB methods /// /// private static void InitializeVBMethods(Assembly vbAssembly) { Debug.Assert(!s_vbMethodsInitialized); //Handle { Mid(arg1, ar2), Mid(arg1, arg2, arg3) } Type stringsType = vbAssembly.GetType(s_stringsTypeFullName); s_alternativeMethods.Add( stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null), stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int), typeof(int) }, null)); } #endregion } private sealed class CanonicalFunctionDefaultTranslator : CallTranslator { internal CanonicalFunctionDefaultTranslator() : base(GetMethods()) { } private static IEnumerableGetMethods() { //Math functions yield return typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); yield return typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null); yield return typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); yield return typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null); yield return typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); yield return typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null); //Decimal functions yield return typeof(Decimal).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); yield return typeof(Decimal).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, 
null); yield return typeof(Decimal).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); //String functions yield return typeof(String).GetMethod("Replace", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(String), typeof(String) }, null); yield return typeof(String).GetMethod("ToLower", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); yield return typeof(String).GetMethod("ToUpper", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); yield return typeof(String).GetMethod("Trim", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); } // Default translator for method calls into canonical functions. // Translation: // MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn) // this.MethodName(arg1, arg2, .., argn) -> MethodName(this, arg1, arg2, .., argn) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { LinqExpression[] linqArguments; if (!call.Method.IsStatic) { Debug.Assert(call.Object != null, "Instance method without this"); List arguments = new List (call.Arguments.Count + 1); arguments.Add(call.Object); arguments.AddRange(call.Arguments); linqArguments = arguments.ToArray(); } else { linqArguments = call.Arguments.ToArray(); } return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, linqArguments); } } #region System.String Method Translators private sealed class ContainsTranslator : CallTranslator { internal ContainsTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Contains", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.Contains(argument) -> IndexOf(argument, object) > 0 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression indexOfExpression = 
parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); DbComparisonExpression comparisonExpression = parent._commandTree.CreateGreaterThanExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(0)); return comparisonExpression; } } private sealed class IndexOfTranslator : CallTranslator { internal IndexOfTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("IndexOf", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // IndexOf(arg1) -> IndexOf(arg1, this) - 1 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IndexOf"); DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); CqtExpression minusExpression = parent._commandTree.CreateMinusExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(1)); return minusExpression; } } private sealed class StartsWithTranslator : CallTranslator { internal StartsWithTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("StartsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.StartsWith(argument) -> IndexOf(argument, object) == 1 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(1)); return comparisonExpression; } } private sealed class 
EndsWithTranslator : CallTranslator { internal EndsWithTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("EndsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.EndsWith(argument) -> Right(object, Length(argument)) == argument internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression lengthExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]); DbExpression rightExpression = parent.CreateCanonicalFunction(ExpressionConverter.Right, call, parent.TranslateExpression(call.Object), lengthExpression); DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression( rightExpression, parent.TranslateExpression(call.Arguments[0])); return comparisonExpression; } } private sealed class SubstringTranslator : CallTranslator { internal SubstringTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); } // Translation: // Substring(arg1) -> Substring(this, arg1+1, Length(this) - arg1)) // Substring(arg1, arg2) -> Substring(this, arg1+1, arg2) // internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Substring"); CqtExpression target = parent.TranslateExpression(call.Object); CqtExpression fromIndex = parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent._commandTree.CreateConstantExpression(1)); CqtExpression 
length; if (call.Arguments.Count == 1) { length = parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent.TranslateExpression(call.Arguments[0])); } else { length = parent.TranslateExpression(call.Arguments[1]); } CqtExpression substringExpression = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, target, fromIndex, length); return substringExpression; } } private sealed class RemoveTranslator : CallTranslator { internal RemoveTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); } // Translation: // Remove(arg1) -> Substring(this, 1, arg1) // Remove(arg1, arg2) -> Concat(Substring(this, 1, arg1) , Substring(this, arg1 + arg2 + 1, Length(this) - (arg1 + arg2))) // Remove(arg1, arg2) is only supported if arg2 is a non-negative integer internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Remove"); //Substring(this, 1, arg1) CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreateConstantExpression(1), parent.TranslateExpression(call.Arguments[0])); //Concat(result, Substring(this, (arg1 + arg2) +1, Length(this) - (arg1 + arg2))) if (call.Arguments.Count == 2) { //If there are two arguemtns, we only support cases when the second one translates to a non-negative constant CqtExpression translatedArgument1 = parent.TranslateExpression(call.Arguments[1]); if (!IsNonNegativeIntegerConstant(translatedArgument1)) { throw 
EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedStringRemoveCase(call.Method, call.Method.GetParameters()[1].Name)); } // Build the second substring // (arg1 + arg2) +1 CqtExpression substringStartIndex = parent._commandTree.CreatePlusExpression( parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), translatedArgument1), parent._commandTree.CreateConstantExpression(1)); // Length(this) - (arg1 + arg2) CqtExpression substringLength = parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent.TranslateExpression(call.Arguments[1]))); // Substring(this, substringStartIndex, substringLenght) CqtExpression secondSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), substringStartIndex, substringLength); // result = Concat (result, secondSubstring) result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, secondSubstring); } return result; } private static bool IsNonNegativeIntegerConstant(CqtExpression argument) { // Check whether it is a constant of type Int32 if (argument.ExpressionKind != DbExpressionKind.Constant || !TypeSemantics.IsPrimitiveType(argument.ResultType, PrimitiveTypeKind.Int32)) { return false; } // Check whether its value is non-negative DbConstantExpression constantExpression = (DbConstantExpression)argument; int value = (int)constantExpression.Value; if (value < 0) { return false; } return true; } } private sealed class InsertTranslator : CallTranslator { internal InsertTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Insert", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(string) }, null); } // Translation: // Insert(startIndex, value) 
-> Concat(Concat(Substring(this, 1, startIndex), value), Substring(this, startIndex+1, Length(this) - startIndex)) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 2, "Expecting 2 arguments for String.Insert"); //Substring(this, 1, startIndex) CqtExpression firstSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreateConstantExpression(1), parent.TranslateExpression(call.Arguments[0])); //Substring(this, startIndex+1, Length(this) - startIndex) CqtExpression secondSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent._commandTree.CreateConstantExpression(1)), parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent.TranslateExpression(call.Arguments[0]))); // result = Concat( Concat (firstSubstring, value), secondSubstring ) CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, firstSubstring, parent.TranslateExpression(call.Arguments[1])), secondSubstring); return result; } } private sealed class IsNullOrEmptyTranslator : CallTranslator { internal IsNullOrEmptyTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("IsNullOrEmpty", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); } // Translation: // IsNullOrEmpty(value) -> (IsNull(value)) OR Length(value) = 0 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IsNullOrEmpty"); 
//IsNull(value) CqtExpression isNullExpression = parent._commandTree.CreateIsNullExpression( parent.TranslateExpression(call.Arguments[0])); //Length(value) = 0 CqtExpression emptyStringExpression = parent._commandTree.CreateEqualsExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]), parent._commandTree.CreateConstantExpression(0)); CqtExpression result = parent._commandTree.CreateOrExpression(isNullExpression, emptyStringExpression); return result; } } private sealed class StringConcatTranslator : CallTranslator { internal StringConcatTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string) }, null); yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string) }, null); yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) }, null); } // Translation: // Concat (arg1, arg2) -> Concat(arg1, arg2) // Concat (arg1, arg2, arg3) -> Concat(Concat(arg1, arg2), arg3) // Concat (arg1, arg2, arg3, arg4) -> Concat(Concat(Concat(arg1, arg2), arg3), arg4) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count >= 2 && call.Arguments.Count <= 4, "Expecting between 2 and 4 arguments for String.Concat"); CqtExpression result = parent.TranslateExpression(call.Arguments[0]); for (int argIndex = 1; argIndex < call.Arguments.Count; argIndex++) { // result = Concat(result, arg[argIndex]) result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, parent.TranslateExpression(call.Arguments[argIndex])); } return result; } } private abstract class 
TrimStartTrimEndBaseTranslator : CallTranslator { private string _canonicalFunctionName; protected TrimStartTrimEndBaseTranslator(IEnumerable methods, string canonicalFunctionName) : base(methods) { _canonicalFunctionName = canonicalFunctionName; } // Translation: // object.MethodName -> CanonicalFunctionName(object) // Supported only if the argument is an empty array. internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { if (!IsEmptyArray(call.Arguments[0])) { throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedTrimStartTrimEndCase(call.Method)); } return parent.TranslateIntoCanonicalFunction(_canonicalFunctionName, call, call.Object); } internal static bool IsEmptyArray(LinqExpression expression) { if (expression.NodeType != ExpressionType.NewArrayInit) { return false; } NewArrayExpression newArray = (NewArrayExpression)expression; if (newArray.Expressions.Count != 0) { return false; } return true; } } private sealed class TrimStartTranslator : TrimStartTrimEndBaseTranslator { internal TrimStartTranslator() : base(GetMethods(), ExpressionConverter.LTrim) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("TrimStart", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); } } private sealed class TrimEndTranslator : TrimStartTrimEndBaseTranslator { internal TrimEndTranslator() : base(GetMethods(), ExpressionConverter.RTrim) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("TrimEnd", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); } } #endregion #region Visual Basic Specific Translators private sealed class VBCanonicalFunctionDefaultTranslator : CallTranslator { private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings"; private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime"; internal 
VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly) : base(GetMethods(vbAssembly)) { } private static IEnumerable GetMethods(Assembly vbAssembly) { //Strings Types Type stringsType = vbAssembly.GetType(s_stringsTypeFullName); yield return stringsType.GetMethod("Trim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("LTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("RTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("Left", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); yield return stringsType.GetMethod("Right", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); //DateTimeType Type dateTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName); yield return dateTimeType.GetMethod("Year", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Month", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Day", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Hour", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Minute", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Second", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); } // Default translator for vb static method calls into canonical functions. 
// Translation:
    //      MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, call.Arguments.ToArray());
    }
}

private sealed class VBCanonicalFunctionRenameTranslator : CallTranslator
{
    private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

    // Maps each supported VB method to the name of the canonical function it translates to.
    // NOTE(review): the dictionary is not synchronized; presumably translators are
    // constructed under a lock elsewhere in the file - confirm.
    private static readonly Dictionary<MethodInfo, string> s_methodNameMap = new Dictionary<MethodInfo, string>(4);

    internal VBCanonicalFunctionRenameTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly))
    {
    }

    // Enumerates the supported VB string methods; each one is registered in s_methodNameMap
    // as a side effect of being yielded.
    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        //Strings Types
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        yield return GetMethod(stringsType, "Len", ExpressionConverter.Length, new Type[] { typeof(string) });
        yield return GetMethod(stringsType, "Mid", ExpressionConverter.Substring, new Type[] { typeof(string), typeof(int), typeof(int) });
        yield return GetMethod(stringsType, "UCase", ExpressionConverter.ToUpper, new Type[] { typeof(string) });
        yield return GetMethod(stringsType, "LCase", ExpressionConverter.ToLower, new Type[] { typeof(string) });
    }

    // Resolves the public static method and records the canonical function name it maps to.
    private static MethodInfo GetMethod(Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes)
    {
        MethodInfo methodInfo = declaringType.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static, null, argumentTypes, null);

        // Use the indexer rather than Add: GetMethods is a lazy iterator and the map is static,
        // so a second enumeration (or a second translator instance) would otherwise throw
        // ArgumentException for the duplicate key. The value written is always the same.
        s_methodNameMap[methodInfo] = canonicalFunctionName;
        return methodInfo;
    }

    // Translator for static method calls into canonical functions when only the name of the canonical function
    // is different from the name of the method, but the arguments match.
// Translation:
    //      MethodName(arg1, arg2, .., argn) -> CanonicalFunctionName(arg1, arg2, .., argn)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        return parent.TranslateIntoCanonicalFunction(s_methodNameMap[call.Method], call, call.Arguments.ToArray());
    }
}

private sealed class VBDatePartTranslator : CallTranslator
{
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";
    private const string s_DateIntervalFullName = "Microsoft.VisualBasic.DateInterval";
    private const string s_FirstDayOfWeekFullName = "Microsoft.VisualBasic.FirstDayOfWeek";
    private const string s_FirstWeekOfYearFullName = "Microsoft.VisualBasic.FirstWeekOfYear";

    // Interval names accepted by Translate; populated once in the static constructor.
    private static HashSet<string> s_supportedIntervals;

    internal VBDatePartTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly))
    {
    }

    static VBDatePartTranslator()
    {
        s_supportedIntervals = new HashSet<string>();
        s_supportedIntervals.Add(ExpressionConverter.Year);
        s_supportedIntervals.Add(ExpressionConverter.Month);
        s_supportedIntervals.Add(ExpressionConverter.Day);
        s_supportedIntervals.Add(ExpressionConverter.Hour);
        s_supportedIntervals.Add(ExpressionConverter.Minute);
        s_supportedIntervals.Add(ExpressionConverter.Second);
    }

    // Yields the four-argument DateAndTime.DatePart overload, resolving its enum parameter
    // types from the supplied VB assembly by full name.
    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        Type dateIntervalEnum = vbAssembly.GetType(s_DateIntervalFullName);
        Type firstDayOfWeekEnum = vbAssembly.GetType(s_FirstDayOfWeekFullName);
        Type firstWeekOfYearEnum = vbAssembly.GetType(s_FirstWeekOfYearFullName);

        yield return dateAndTimeType.GetMethod("DatePart", BindingFlags.Public | BindingFlags.Static, null,
            new Type[] { dateIntervalEnum, typeof(DateTime), firstDayOfWeekEnum, firstWeekOfYearEnum }, null);
    }

    // Translation:
    //      DatePart(DateInterval, date, arg3, arg4)  ->  'DateInterval'(date)
    // Note: it is only supported for the values of DateInterval listed in s_supportedIntervals.
internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 4, "Expecting 4 arguments for Microsoft.VisualBasic.DateAndTime.DatePart");

        // The interval must be a compile-time constant so its name can select the canonical function.
        ConstantExpression intervalLinqExpression = call.Arguments[0] as ConstantExpression;
        if (intervalLinqExpression == null)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartNonConstantInterval(call.Method, call.Method.GetParameters()[0].Name));
        }

        string intervalValue = intervalLinqExpression.Value.ToString();
        if (!s_supportedIntervals.Contains(intervalValue))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartInvalidInterval(call.Method, call.Method.GetParameters()[0].Name, intervalValue));
        }

        // Only the date argument is forwarded; arguments 3 and 4 are ignored.
        CqtExpression result = parent.TranslateIntoCanonicalFunction(intervalValue, call, call.Arguments[1]);
        return result;
    }
}
#endregion

#endregion

#region Sequence method translators
// Base class for translators keyed by SequenceMethod values.
private abstract class SequenceMethodTranslator
{
    private readonly IEnumerable<SequenceMethod> _methods;

    protected SequenceMethodTranslator(params SequenceMethod[] methods)
    {
        _methods = methods;
    }

    // The sequence methods this translator was constructed with.
    internal IEnumerable<SequenceMethod> Methods
    {
        get { return _methods; }
    }

    // By default the identified sequence method is ignored and the two-argument overload is used.
    internal virtual CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
    {
        return Translate(parent, call);
    }

    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    public override string ToString()
    {
        return GetType().Name;
    }
}

// Shared logic for Skip and Take: translates the count argument and applies the paging
// operator underneath any projection at the root of the operand.
private abstract class PagingTranslator : UnarySequenceMethodTranslator
{
    protected PagingTranslator(params SequenceMethod[] methods)
        : base(methods) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // translate count expression
        Debug.Assert(call.Arguments.Count == 2, "Skip and Take must have 2 arguments");
        LinqExpression linqCount = call.Arguments[1];
        CqtExpression count = parent.TranslateExpression(linqCount);

        // remove projections at the root of the expression and then reapply after applying the paging operator
        DbProjectExpression projection = null;
        if (operand.ExpressionKind == DbExpressionKind.Project)
        {
            projection = (DbProjectExpression)operand;
            operand = projection.Input.Expression;
        }

        // translate paging expression
        DbExpression result = TranslatePagingOperator(parent, operand, count);

        // reapply project as necessary (the existing projection node is re-pointed at the paged input)
        if (null != projection)
        {
            projection.Input.Expression = result;
            result = projection;
        }

        return result;
    }

    protected abstract CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count);
}

private sealed class TakeTranslator : PagingTranslator
{
    internal TakeTranslator()
        : base(SequenceMethod.Take) { }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        return parent.Limit(operand, count);
    }
}

private sealed class SkipTranslator : PagingTranslator
{
    internal SkipTranslator()
        : base(SequenceMethod.Skip) { }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        // skip requires a sorted input
        if (operand.ExpressionKind != DbExpressionKind.Sort)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_SkipWithoutOrder);
        }

        DbSortExpression sortedOperand = (DbSortExpression)operand;
        Span sortedSpan = null;
        bool hadSpan = parent.TryGetSpan(sortedOperand, out sortedSpan);

        // generate a skip statement with the sort order of the original sorted input.
        DbSkipExpression skip = parent.Skip(sortedOperand.Input, sortedOperand.SortOrder, count);

        // If the original DbSortExpression had Span information, then this is applied
        // to the newly created DbSkipExpression before returning.
if (hadSpan)
        {
            parent.AddSpanMapping(skip, sortedSpan);
        }
        return skip;
    }
}

// Single/SingleOrDefault (with or without predicate) always fail translation.
private sealed class SingleTranslatorNotSupported : SequenceMethodTranslator
{
    internal SingleTranslatorNotSupported()
        : base(SequenceMethod.Single, SequenceMethod.SinglePredicate,
        SequenceMethod.SingleOrDefault, SequenceMethod.SingleOrDefaultPredicate)
    {
    }

    internal override System.Data.Common.CommandTrees.DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedSingle);
    }
}

// Translates Join(outer, inner, outerKey, innerKey, resultSelector) into a projection over an inner join.
private sealed class JoinTranslator : SequenceMethodTranslator
{
    internal JoinTranslator()
        : base(SequenceMethod.Join) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(5 == call.Arguments.Count);

        // get expressions describing inputs to the join
        CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
        CqtExpression inner = parent.TranslateSet(call.Arguments[1]);

        // get expressions describing key selectors
        LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
        LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);

        // translate key selectors
        DbExpressionBinding outerBinding;
        DbExpressionBinding innerBinding;
        CqtExpression outerKeySelector = parent.TranslateLambda(outerLambda, outer, out outerBinding);
        CqtExpression innerKeySelector = parent.TranslateLambda(innerLambda, inner, out innerBinding);

        // construct join expression
        if (!TypeSemantics.IsEqualComparable(outerKeySelector.ResultType) ||
            !TypeSemantics.IsEqualComparable(innerKeySelector.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }
        DbJoinExpression join = parent._commandTree.CreateInnerJoinExpression(
            outerBinding, innerBinding,
            parent.CreateEqualsExpression(outerKeySelector, innerKeySelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type));
        DbExpressionBinding joinBinding = parent._commandTree.CreateExpressionBinding(join);

        // get selector expression
        LambdaExpression selectorLambda = parent.GetLambdaExpression(call, 4);

        // create property expressions for the inner and outer
        DbPropertyExpression joinOuter = parent._commandTree.CreatePropertyExpression(
            outerBinding.VariableName, joinBinding.Variable);
        DbPropertyExpression joinInner = parent._commandTree.CreatePropertyExpression(
            innerBinding.VariableName, joinBinding.Variable);

        // push outer and inner join parts into the binding scope (the order
        // is irrelevant because the binding context matches based on parameter
        // reference rather than ordinal)
        parent._bindingContext.PushBindingScope(
            new Binding(selectorLambda.Parameters[0], joinOuter),
            new Binding(selectorLambda.Parameters[1], joinInner));

        // translate join selector
        CqtExpression selector = parent.TranslateExpression(selectorLambda.Body);

        // pop binding scope
        parent._bindingContext.PopBindingScope();

        return parent._commandTree.CreateProjectExpression(joinBinding, selector);
    }
}

// Base class for two-input set operators; accepts both the instance-method and the
// static extension-method call shape.
private abstract class BinarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected BinarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        if (null != call.Object)
        {
            // instance method
            Debug.Assert(1 == call.Arguments.Count);
            CqtExpression left = parent.TranslateSet(call.Object);
            CqtExpression right = parent.TranslateSet(call.Arguments[0]);
            return TranslateBinary(parent, left, right);
        }
        else
        {
            // static extension method
            Debug.Assert(2 == call.Arguments.Count);
            CqtExpression left = parent.TranslateSet(call.Arguments[0]);
            CqtExpression right = parent.TranslateSet(call.Arguments[1]);
            return TranslateBinary(parent, left, right);
        }
    }

    protected abstract CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right);
}

private class ConcatTranslator : BinarySequenceMethodTranslator
{
    internal ConcatTranslator()
        : base(SequenceMethod.Concat) { }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.UnionAll(left, right);
    }
}

private sealed class UnionTranslator : BinarySequenceMethodTranslator
{
    internal UnionTranslator()
        : base(SequenceMethod.Union) { }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        // Union = Distinct over UnionAll
        return parent.Distinct(parent.UnionAll(left, right));
    }
}

private sealed class IntersectTranslator : BinarySequenceMethodTranslator
{
    internal IntersectTranslator()
        : base(SequenceMethod.Intersect) { }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.Intersect(left, right);
    }
}

private sealed class ExceptTranslator : BinarySequenceMethodTranslator
{
    internal ExceptTranslator()
        : base(SequenceMethod.Except) { }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.Except(left, right);
    }
}

// Base class for aggregate sequence methods (one or two arguments: source and optional lambda).
private abstract class AggregateTranslator : SequenceMethodTranslator
{
    private readonly string _functionName;
    private readonly bool _takesPredicate;

    protected AggregateTranslator(string functionName, bool takesPredicate, params SequenceMethod[] methods)
        : base(methods)
    {
        _takesPredicate = takesPredicate;
        _functionName = functionName;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool isUnary = 1 == call.Arguments.Count;
        Debug.Assert(isUnary || 2 == call.Arguments.Count);

        CqtExpression operand = parent.TranslateSet(call.Arguments[0]);

        // Facet information for the return type cannot help in determining the appropriate function overload
        // since no constraints on the return value are known.
TypeUsage returnType = parent.GetValueLayerType(call.Type);
        LambdaExpression lambda = null;

        //We save the original operand for the optimized translation
        CqtExpression originalOperand = operand;

        if (!isUnary)
        {
            lambda = parent.GetLambdaExpression(call, 1);
            DbExpressionBinding sourceBinding;
            CqtExpression cqtLambda = parent.TranslateLambda(lambda, operand, out sourceBinding);

            if (_takesPredicate)
            {
                // treat the lambda as a filter
                operand = parent.Filter(sourceBinding, cqtLambda);
            }
            else
            {
                // treat the lambda as a selector
                operand = parent._commandTree.CreateProjectExpression(sourceBinding, cqtLambda);
            }
        }

        operand = WrapCollectionOperand(parent, operand, returnType);
        DbGroupExpressionBinding operandBinding = parent._commandTree.CreateGroupExpressionBinding(operand);
        EdmFunction function = FindFunction(parent, call, returnType);

        // create aggregate
        List<KeyValuePair<string, DbExpression>> keys = new List<KeyValuePair<string, DbExpression>>(0); // no key
        List<KeyValuePair<string, DbAggregate>> aggregates = new List<KeyValuePair<string, DbAggregate>>(1);
        const string aggregateName = "Aggregate";
        aggregates.Add(new KeyValuePair<string, DbAggregate>(aggregateName, // name is arbitrary (there is only one in this context)
            parent._commandTree.CreateFunctionAggregate(function, operandBinding.GroupVariable)));
        DbGroupByExpression aggregate = parent._commandTree.CreateGroupByExpression(
            operandBinding, keys, aggregates);
        DbExpressionBinding aggregateBinding = parent._commandTree.CreateExpressionBinding(aggregate);

        // project result
        DbPropertyExpression property = parent._commandTree.CreatePropertyExpression(
            aggregateName, aggregateBinding.Variable);
        DbProjectExpression projection = parent._commandTree.CreateProjectExpression(
            aggregateBinding, parent.AlignTypes(property, call.Type));

        // return a single element to represent the projection
        DbElementExpression element = parent._commandTree.CreateElementExpression(projection);

        // Try to create and log an optimized translation
        TryCreateOptimizedTranslation(parent, lambda, originalOperand, function, element);

        return element;
    }

    // If the function is over a group by, it tries to
// incorporate the aggregate function into the group by.
    // If it does, it gives the aggregate an alias and records it in
    // parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.
    private void TryCreateOptimizedTranslation(ExpressionConverter parent, LambdaExpression lambda, CqtExpression operand, EdmFunction function, CqtExpression originalTranslation)
    {
        //Aggregates that take predicates as arguments cannot be incorporated into a group by
        if (_takesPredicate && (lambda != null))
        {
            return;
        }

        //Check whether the operand is a property over an output of grouping
        if (operand.ExpressionKind != DbExpressionKind.Property)
        {
            return;
        }

        DbPropertyExpression propertyExpression = (DbPropertyExpression)operand;
        if (propertyExpression.Instance.ExpressionKind != DbExpressionKind.VariableReference)
        {
            return;
        }
        DbVariableReferenceExpression inputVarRef = (DbVariableReferenceExpression)propertyExpression.Instance;

        //If the input corresponding to the var ref has an optimized translation, generate an alternate translation for this as well.
        CqtExpression input;
        if (!parent._variableNameToInputExpression.TryGetValue(inputVarRef.VariableName, out input))
        {
            return;
        }

        DbGroupByTemplate optimizedTranslationOfInput;
        if (!parent._groupByDefaultToOptimizedTranslationMap.TryGetValue(input, out optimizedTranslationOfInput))
        {
            return;
        }

        Debug.Assert(TypeSemantics.IsCollectionType(function.Parameters[0].TypeUsage), "Aggregates should always have collection arguments");
        TypeUsage elementType = TypeHelpers.GetElementTypeUsage(function.Parameters[0].TypeUsage);

        // The aggregate can be added to the list of aggregates.
CqtExpression aggregateArgument = optimizedTranslationOfInput.Input.GroupVariable;
        if (lambda != null)
        {
            aggregateArgument = parent.TranslateLambda(lambda, aggregateArgument);
        }

        // The alias is made unique by suffixing the current number of aggregates in the template.
        string optimizedTranslationAlias = String.Format(CultureInfo.InvariantCulture, "Aggregate{0}", optimizedTranslationOfInput.Aggregates.Count);
        optimizedTranslationOfInput.Aggregates.Add(new KeyValuePair<string, DbAggregate>(optimizedTranslationAlias,
            parent._commandTree.CreateFunctionAggregate(function, WrapNonCollectionOperand(parent, aggregateArgument, elementType))));

        //log the alias
        parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.Add(originalTranslation,
            new KeyValuePair<DbGroupByTemplate, string>(optimizedTranslationOfInput, optimizedTranslationAlias));
    }

    // If necessary, wraps the operand to ensure the appropriate aggregate overload is called
    protected virtual CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // check if the operand needs to be wrapped to ensure the correct function overload is called
        if (!TypeUsageEquals(returnType, ((CollectionType)operand.ResultType.EdmType).TypeUsage))
        {
            // project each element through a cast to the return type
            DbExpressionBinding operandCastBinding = parent._commandTree.CreateExpressionBinding(operand);
            DbProjectExpression operandCastProjection = parent._commandTree.CreateProjectExpression(
                operandCastBinding, parent._commandTree.CreateCastExpression(operandCastBinding.Variable, returnType));
            operand = operandCastProjection;
        }
        return operand;
    }

    // If necessary, wraps the operand to ensure the appropriate aggregate overload is called
    protected virtual CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        if (!TypeUsageEquals(returnType, operand.ResultType))
        {
            operand = parent._commandTree.CreateCastExpression(operand, returnType);
        }
        return operand;
    }

    // Finds the best function overload given the expected return type
    protected virtual EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call, TypeUsage argumentType)
    {
        List<TypeUsage> argTypes = new List<TypeUsage>(1);
        // In general, we use the return type as the parameter type to align LINQ semantics
        // with SQL semantics, and avoid apparent loss of precision for some LINQ aggregate operators.
        // (e.g., AVG(1, 2) = 2.0, AVG((double)1, (double)2)) = 1.5)
        argTypes.Add(argumentType);

        return parent.FindCanonicalFunction(_functionName, argTypes, true /* isGroupAggregateFunction */, call);
    }
}

private sealed class MaxTranslator : AggregateTranslator
{
    internal MaxTranslator()
        : base("MAX", false,
        SequenceMethod.Max,
        SequenceMethod.MaxSelector,
        SequenceMethod.MaxInt,
        SequenceMethod.MaxIntSelector,
        SequenceMethod.MaxDecimal,
        SequenceMethod.MaxDecimalSelector,
        SequenceMethod.MaxDouble,
        SequenceMethod.MaxDoubleSelector,
        SequenceMethod.MaxLong,
        SequenceMethod.MaxLongSelector,
        SequenceMethod.MaxSingle,
        SequenceMethod.MaxSingleSelector,
        SequenceMethod.MaxNullableDecimal,
        SequenceMethod.MaxNullableDecimalSelector,
        SequenceMethod.MaxNullableDouble,
        SequenceMethod.MaxNullableDoubleSelector,
        SequenceMethod.MaxNullableInt,
        SequenceMethod.MaxNullableIntSelector,
        SequenceMethod.MaxNullableLong,
        SequenceMethod.MaxNullableLongSelector,
        SequenceMethod.MaxNullableSingle,
        SequenceMethod.MaxNullableSingleSelector)
    {
    }
}

private sealed class MinTranslator : AggregateTranslator
{
    internal MinTranslator()
        : base("MIN", false,
        SequenceMethod.Min,
        SequenceMethod.MinSelector,
        SequenceMethod.MinDecimal,
        SequenceMethod.MinDecimalSelector,
        SequenceMethod.MinDouble,
        SequenceMethod.MinDoubleSelector,
        SequenceMethod.MinInt,
        SequenceMethod.MinIntSelector,
        SequenceMethod.MinLong,
        SequenceMethod.MinLongSelector,
        SequenceMethod.MinNullableDecimal,
        SequenceMethod.MinSingle,
        SequenceMethod.MinSingleSelector,
        SequenceMethod.MinNullableDecimalSelector,
        SequenceMethod.MinNullableDouble,
        SequenceMethod.MinNullableDoubleSelector,
        SequenceMethod.MinNullableInt,
        SequenceMethod.MinNullableIntSelector,
        SequenceMethod.MinNullableLong,
        SequenceMethod.MinNullableLongSelector,
        SequenceMethod.MinNullableSingle,
        SequenceMethod.MinNullableSingleSelector)
    {
    }
}

private sealed class AverageTranslator : AggregateTranslator
{
    internal AverageTranslator()
        : base("AVG", false,
        SequenceMethod.AverageDecimal,
        SequenceMethod.AverageDecimalSelector,
        SequenceMethod.AverageDouble,
        SequenceMethod.AverageDoubleSelector,
        SequenceMethod.AverageInt,
        SequenceMethod.AverageIntSelector,
        SequenceMethod.AverageLong,
        SequenceMethod.AverageLongSelector,
        SequenceMethod.AverageSingle,
        SequenceMethod.AverageSingleSelector,
        SequenceMethod.AverageNullableDecimal,
        SequenceMethod.AverageNullableDecimalSelector,
        SequenceMethod.AverageNullableDouble,
        SequenceMethod.AverageNullableDoubleSelector,
        SequenceMethod.AverageNullableInt,
        SequenceMethod.AverageNullableIntSelector,
        SequenceMethod.AverageNullableLong,
        SequenceMethod.AverageNullableLongSelector,
        SequenceMethod.AverageNullableSingle,
        SequenceMethod.AverageNullableSingleSelector)
    {
    }
}

private sealed class SumTranslator : AggregateTranslator
{
    internal SumTranslator()
        : base("SUM", false,
        SequenceMethod.SumDecimal,
        SequenceMethod.SumDecimalSelector,
        SequenceMethod.SumDouble,
        SequenceMethod.SumDoubleSelector,
        SequenceMethod.SumInt,
        SequenceMethod.SumIntSelector,
        SequenceMethod.SumLong,
        SequenceMethod.SumLongSelector,
        SequenceMethod.SumSingle,
        SequenceMethod.SumSingleSelector,
        SequenceMethod.SumNullableDecimal,
        SequenceMethod.SumNullableDecimalSelector,
        SequenceMethod.SumNullableDouble,
        SequenceMethod.SumNullableDoubleSelector,
        SequenceMethod.SumNullableInt,
        SequenceMethod.SumNullableIntSelector,
        SequenceMethod.SumNullableLong,
        SequenceMethod.SumNullableLongSelector,
        SequenceMethod.SumNullableSingle,
        SequenceMethod.SumNullableSingleSelector)
    {
    }
}

// Shared logic for Count and LongCount: the lambda (if any) is a predicate and the
// aggregate is always applied to a constant rather than to the element itself.
private abstract class CountTranslatorBase : AggregateTranslator
{
    protected CountTranslatorBase(string functionName, params SequenceMethod[] methods)
        : base(functionName, true, methods)
    {
    }

    protected override CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // always count a constant value
        DbExpressionBinding operandBinding = parent._commandTree.CreateExpressionBinding(operand);
        DbProjectExpression constantProject = parent._commandTree.CreateProjectExpression(
            operandBinding, parent._commandTree.CreateTrueExpression());
        return constantProject;
    }

    protected override CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // always count a constant value
        DbExpression constantExpression = parent._commandTree.CreateConstantExpression(1);
        if (!TypeUsageEquals(constantExpression.ResultType, returnType))
        {
            constantExpression = parent._commandTree.CreateCastExpression(constantExpression, returnType);
        }
        return constantExpression;
    }

    protected override EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call, TypeUsage argumentType)
    {
        // For most ELinq aggregates, the argument type is the return type. For "count", the
        // argument type is always Boolean, since we project a constant Boolean value in WrapCollectionOperand.
TypeUsage booleanTypeUsage = parent._commandTree.CreateTrueExpression().ResultType;
        return base.FindFunction(parent, call, booleanTypeUsage);
    }
}

private sealed class CountTranslator : CountTranslatorBase
{
    internal CountTranslator()
        : base("COUNT", SequenceMethod.Count, SequenceMethod.CountPredicate)
    {
    }
}

private sealed class LongCountTranslator : CountTranslatorBase
{
    internal LongCountTranslator()
        : base("BIGCOUNT", SequenceMethod.LongCount, SequenceMethod.LongCountPredicate)
    {
    }
}

// Base class for single-input sequence operators; the input is call.Object for instance
// methods and the first argument for static extension methods.
private abstract class UnarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected UnarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        if (null != call.Object)
        {
            // instance method
            Debug.Assert(0 <= call.Arguments.Count);
            CqtExpression operand = parent.TranslateSet(call.Object);
            return TranslateUnary(parent, operand, call);
        }
        else
        {
            // static extension method
            Debug.Assert(1 <= call.Arguments.Count);
            CqtExpression operand = parent.TranslateSet(call.Arguments[0]);
            return TranslateUnary(parent, operand, call);
        }
    }

    protected abstract CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call);
}

// AsQueryable / AsEnumerable: returns the translated operand unchanged.
private sealed class PassthroughTranslator : UnarySequenceMethodTranslator
{
    internal PassthroughTranslator()
        : base(SequenceMethod.AsQueryableGeneric, SequenceMethod.AsQueryable, SequenceMethod.AsEnumerable) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // make sure the operand has collection type to avoid treating (for instance) String as a
        // sub-query
        if (TypeSemantics.IsCollectionType(operand.ResultType))
        {
            return operand;
        }
        else
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedPassthrough(
                call.Method.Name, operand.ResultType.EdmType.Name));
        }
    }
}

private sealed class OfTypeTranslator : UnarySequenceMethodTranslator
{
    internal OfTypeTranslator()
        : base(SequenceMethod.OfType) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // The target type is the generic argument of the OfType call.
        Type clrType = call.Method.GetGenericArguments()[0];
        TypeUsage modelType;

        // If the model type does not exist in the perspective or is not either an EntityType
        // or a ComplexType, fail - OfType<T>() is not a valid operation on scalars,
        // enumerations, collections, etc.
        if (!parent.TryGetValueLayerType(clrType, out modelType) ||
            !(TypeSemantics.IsEntityType(modelType) || TypeSemantics.IsComplexType(modelType)))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_InvalidOfTypeResult(DescribeClrType(clrType)));
        }

        // Create an of type expression to filter the original query to include
        // only those results that are of the specified type.
        CqtExpression ofTypeExpression = parent.OfType(operand, modelType);
        return ofTypeExpression;
    }
}

private sealed class DistinctTranslator : UnarySequenceMethodTranslator
{
    internal DistinctTranslator()
        : base(SequenceMethod.Distinct) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        return parent.Distinct(operand);
    }
}

private sealed class AnyTranslator : UnarySequenceMethodTranslator
{
    internal AnyTranslator()
        : base(SequenceMethod.Any) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // "Any" is equivalent to "exists".
return parent._commandTree.CreateNotExpression(
            parent._commandTree.CreateIsEmptyExpression(operand));
    }
}

// Base class for sequence operators that take a source and exactly one lambda
// (the lambda is read from argument index 1).
private abstract class OneLambdaTranslator : SequenceMethodTranslator
{
    internal OneLambdaTranslator(params SequenceMethod[] methods)
        : base(methods) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        return Translate(parent, call, out source, out sourceBinding, out lambda);
    }

    // Helper method for translation; also hands back the translated source, its binding
    // and the translated lambda body for callers that need them.
    protected CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call,
        out CqtExpression source, out DbExpressionBinding sourceBinding, out CqtExpression lambda)
    {
        Debug.Assert(2 <= call.Arguments.Count);

        // translate source
        source = parent.TranslateExpression(call.Arguments[0]);

        // translate lambda expression
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        lambda = parent.TranslateLambda(lambdaExpression, source, out sourceBinding);
        return TranslateOneLambda(parent, sourceBinding, lambda);
    }

    protected abstract CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda);
}

private sealed class AnyPredicateTranslator : OneLambdaTranslator
{
    internal AnyPredicateTranslator()
        : base(SequenceMethod.AnyPredicate) { }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent._commandTree.CreateAnyExpression(sourceBinding, lambda);
    }
}

private sealed class AllTranslator : OneLambdaTranslator
{
    internal AllTranslator()
        : base(SequenceMethod.All) { }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent._commandTree.CreateAllExpression(sourceBinding, lambda);
    }
}

private sealed class WhereTranslator : OneLambdaTranslator
{
    internal WhereTranslator()
        : base(SequenceMethod.Where) { }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent.Filter(sourceBinding, lambda);
    }
}

private sealed class SelectTranslator : OneLambdaTranslator
{
    internal SelectTranslator()
        : base(SequenceMethod.Select) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        CqtExpression result = Translate(parent, call, out source, out sourceBinding, out lambda);

        // If the select is over a GroupBy, check whether an optimized translation can be produced
        // The default translation is to 'manually' do group by, the optimized is translating into a
        // DbGroupByExpression
        CqtExpression rewrittenExpression;
        if (parent.TryRewrite(source, sourceBinding, lambda, out rewrittenExpression))
        {
            result = rewrittenExpression;
        }
        return result;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent.Project(sourceBinding, lambda);
    }
}

// Shared logic for First() and FirstOrDefault() without a predicate.
private abstract class FirstTranslatorBase : UnarySequenceMethodTranslator
{
    protected FirstTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    // True for the OrDefault variant.
    private readonly bool _orDefault;

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // Apply a Limit(1) expression to restrict the result to a single element
        // (this returns an empty set if the input set is initially empty).
        CqtExpression result = parent.Limit(operand, parent._commandTree.CreateConstantExpression(1));

        // If this First() or FirstOrDefault() operation is the root of the query,
        // then the evaluation is performed in the client over the resulting set,
        // to provide the same semantics as Linq to Objects.
// Otherwise, an Element
        // expression is applied to retrieve the single element (or null, if empty)
        // from the output set.
        if (!parent.IsQueryRoot(call))
        {
            if (_orDefault)
            {
                result = parent._commandTree.CreateElementExpression(result);
                result = AddDefaultCase(parent, result, call.Type);
            }
            else
            {
                // First is not allowed anywhere other than the root of the query,
                // since First() operations logically executed in the store will not
                // throw an exception if the input set is empty (typically they will
                // simply produce a null result).
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
            }
        }

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span inputSpan = null;
        if (parent.TryGetSpan(operand, out inputSpan))
        {
            parent.AddSpanMapping(result, inputSpan);
        }

        return result;
    }

    // Wraps the element in CASE WHEN element IS NULL THEN <default> ELSE element when the CLR
    // type has a non-null default value; otherwise returns the element unchanged.
    internal static CqtExpression AddDefaultCase(ExpressionConverter parent, CqtExpression element, Type elementType)
    {
        // Retrieve default value.
        object defaultValue = TypeSystem.GetDefaultValue(elementType);
        if (null == defaultValue)
        {
            // Already null, which is the implicit default for DbElementExpression
            return element;
        }

        // Otherwise, use the default value for the type
        List<CqtExpression> whenExpressions = new List<CqtExpression>(1);
        whenExpressions.Add(parent.CreateIsNullExpression(element, elementType));
        List<CqtExpression> thenExpressions = new List<CqtExpression>(1);
        thenExpressions.Add(parent._commandTree.CreateConstantExpression(defaultValue));
        DbCaseExpression caseExpression = parent._commandTree.CreateCaseExpression(
            whenExpressions, thenExpressions, element);
        return caseExpression;
    }
}

private sealed class FirstTranslator : FirstTranslatorBase
{
    internal FirstTranslator()
        : base(false, SequenceMethod.First) { }
}

private sealed class FirstOrDefaultTranslator : FirstTranslatorBase
{
    internal FirstOrDefaultTranslator()
        : base(true, SequenceMethod.FirstOrDefault) { }
}

// Shared logic for First(predicate) and FirstOrDefault(predicate).
private abstract class FirstPredicateTranslatorBase : OneLambdaTranslator
{
    protected FirstPredicateTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    // True for the OrDefault variant.
    private readonly bool _orDefault;

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Convert the input set and the predicate into a filter expression
        CqtExpression input = base.Translate(parent, call);

        // If this First/FirstOrDefault is the root of the query,
        // then the actual result will be produced by
        // calling First() or FirstOrDefault() on the filtered input set,
        // which is limited to at most one element by applying a Limit(1) expression.
        if (parent.IsQueryRoot(call))
        {
            // Calling ExpressionConverter.Limit propagates the Span.
            return parent.Limit(input, parent._commandTree.CreateConstantExpression(1));
        }
        else
        {
            if (_orDefault)
            {
                CqtExpression element = parent._commandTree.CreateElementExpression(input);
                element = FirstTranslatorBase.AddDefaultCase(parent, element, call.Type);

                // Span is preserved over First/FirstOrDefault with or without a predicate
                Span inputSpan = null;
                if (parent.TryGetSpan(input, out inputSpan))
                {
                    parent.AddSpanMapping(element, inputSpan);
                }

                return element;
            }
            else
            {
                // First (with or without a predicate) is not allowed anywhere other
                // than the root of the query, since First(predicate) operations
                // logically executed in the store will not throw an exception if the
                // input set is empty (typically they will simply produce a null result).
throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst); } } } protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) { return parent.Filter(sourceBinding, lambda); } } private sealed class FirstPredicateTranslator : FirstPredicateTranslatorBase { internal FirstPredicateTranslator() : base(false, SequenceMethod.FirstPredicate) { } } private sealed class FirstOrDefaultPredicateTranslator : FirstPredicateTranslatorBase { internal FirstOrDefaultPredicateTranslator() : base(true, SequenceMethod.FirstOrDefaultPredicate) { } } private sealed class SelectManyTranslator : OneLambdaTranslator { internal SelectManyTranslator() : base(SequenceMethod.SelectMany, SequenceMethod.SelectManyResultSelector) { } internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { bool hasSelector = 3 == call.Arguments.Count; CqtExpression crossApply = base.Translate(parent, call); // perform a cross apply to implement the core logic for SelectMany (this translates the collection selector): // SelectMany(i, Func> collectionSelector) => // i CROSS APPLY collectionSelector(i) // The cross-apply yields a collection from which we yield either the right hand side (when // no explicit resultSelector is given) or over which we apply the resultSelector Lambda expression. 
DbExpressionBinding crossApplyBinding = parent._commandTree.CreateExpressionBinding(crossApply); RowType crossApplyRowType = (RowType)(crossApplyBinding.Variable.ResultType.EdmType); CqtExpression projectRight = parent._commandTree.CreatePropertyExpression(crossApplyRowType.Properties[1], crossApplyBinding.Variable); CqtExpression resultProjection; if (hasSelector) { CqtExpression projectLeft = parent._commandTree.CreatePropertyExpression(crossApplyRowType.Properties[0], crossApplyBinding.Variable); LambdaExpression resultSelector = parent.GetLambdaExpression(call, 2); // add the left and right projection terms to the binding context parent._bindingContext.PushBindingScope(new Binding(resultSelector.Parameters[0], projectLeft), new Binding(resultSelector.Parameters[1], projectRight)); // translate the result selector resultProjection = parent.TranslateSet(resultSelector.Body); // pop binding context parent._bindingContext.PopBindingScope(); } else { // project out the right hand side of the apply resultProjection = projectRight; } // wrap result projection in project expression return parent._commandTree.CreateProjectExpression(crossApplyBinding, resultProjection); } protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda) { // elements of the inner selector should be used lambda = parent.NormalizeSetSource(lambda); DbExpressionBinding applyBinding = parent._commandTree.CreateExpressionBinding(lambda); DbApplyExpression crossApply = parent._commandTree.CreateCrossApplyExpression(sourceBinding, applyBinding); return crossApply; } } private sealed class CastMethodTranslator : SequenceMethodTranslator { internal CastMethodTranslator() : base(SequenceMethod.Cast) { } internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { // Translate source CqtExpression source = parent.TranslateSet(call.Arguments[0]); // Figure out the type to cast to Type 
toClrType = TypeSystem.GetElementType(call.Type); Type fromClrType = TypeSystem.GetElementType(call.Arguments[0].Type); // Get binding to the elements of the input source DbExpressionBinding binding = parent._commandTree.CreateExpressionBinding(source); CqtExpression cast = parent.CreateCastExpression(binding.Variable, toClrType, fromClrType); return parent._commandTree.CreateProjectExpression(binding, cast); } } private sealed class GroupByTranslator : SequenceMethodTranslator { internal GroupByTranslator() : base(SequenceMethod.GroupBy, SequenceMethod.GroupByElementSelector, SequenceMethod.GroupByElementSelectorResultSelector, SequenceMethod.GroupByResultSelector) { } // The default translation of GroupBy is: // SELECT d as Key, (SELECT VALUE g FROM source WHERE source.Key = d) as Group // FROM (SELECT DISTINCT source.Key) // // The optimized translation is simply creating a Cqt GroupByExpression internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod) { // translate source CqtExpression source = parent.TranslateSet(call.Arguments[0]); // translate key selector LambdaExpression keySelectorLinq = parent.GetLambdaExpression(call, 1); DbExpressionBinding sourceBinding; CqtExpression keySelector = parent.TranslateLambda(keySelectorLinq, source, out sourceBinding); // translate the key selector again in a different binding context (for the nested select) DbExpressionBinding nestedSourceBinding; CqtExpression nestedSelector = parent.TranslateLambda(keySelectorLinq, source, out nestedSourceBinding); // create distinct expression if (!TypeSemantics.IsEqualComparable(keySelector.ResultType)) { // to avoid confusing error message about the "distinct" type, pre-emptively raise an exception // about the group by key selector throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name)); } CqtExpression distinct = parent.Distinct( 
parent._commandTree.CreateProjectExpression(sourceBinding, keySelector)); DbExpressionBinding distinctBinding = parent._commandTree.CreateExpressionBinding(distinct); // create group projection term DbFilterExpression groupKeyFilter = parent.Filter( nestedSourceBinding, parent.CreateEqualsExpression(nestedSelector, distinctBinding.Variable, EqualsPattern.PositiveNullEquality, keySelectorLinq.Type, keySelectorLinq.Type)); // interpret element selector if needed CqtExpression selection = groupKeyFilter; bool hasElementSelector = sequenceMethod == SequenceMethod.GroupByElementSelector || sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector; if (hasElementSelector) { LambdaExpression elementSelectorLinq = parent.GetLambdaExpression(call, 2); DbExpressionBinding elementSelectorSourceBinding; CqtExpression elementSelector = parent.TranslateLambda(elementSelectorLinq, selection, out elementSelectorSourceBinding); selection = parent._commandTree.CreateProjectExpression(elementSelectorSourceBinding, elementSelector); } // create top level projection List projectionTerms = new List (2); projectionTerms.Add(distinctBinding.Variable); projectionTerms.Add(selection); // build projection type with initializer information List properties = new List (2); properties.Add(new EdmProperty(KeyColumnName, projectionTerms[0].ResultType)); properties.Add(new EdmProperty(GroupColumnName, projectionTerms[1].ResultType)); InitializerMetadata initializerMetadata = InitializerMetadata.CreateGroupingInitializer( parent.EdmItemCollection, TypeSystem.GetElementType(call.Type)); RowType rowType = new RowType(properties, initializerMetadata); TypeUsage rowTypeUsage = TypeUsage.Create(rowType); CqtExpression topLevelProject = parent._commandTree.CreateProjectExpression(distinctBinding, parent._commandTree.CreateNewInstanceExpression(rowTypeUsage, projectionTerms)); if (!hasElementSelector) { //Create optimized translation for the GroupBy - simple GroupBy template 
DbGroupExpressionBinding groupByBinding; CqtExpression newKeySelector = parent.TranslateLambda(keySelectorLinq, source, out groupByBinding); DbGroupByTemplate groupByTemplate = new DbGroupByTemplate(groupByBinding); groupByTemplate.GroupKeys.Add(new KeyValuePair (KeyColumnName, newKeySelector)); parent._groupByDefaultToOptimizedTranslationMap.Add(topLevelProject, groupByTemplate); } var result = topLevelProject; // GroupBy may include a result selector; handle it result = ProcessResultSelector(parent, call, sequenceMethod, topLevelProject, result); return result; } private static DbExpression ProcessResultSelector(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod, CqtExpression topLevelProject, DbExpression result) { // interpret result selector if needed LambdaExpression resultSelectorLinqExpression = null; if (sequenceMethod == SequenceMethod.GroupByResultSelector) { resultSelectorLinqExpression = parent.GetLambdaExpression(call, 2); } else if (sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector) { resultSelectorLinqExpression = parent.GetLambdaExpression(call, 3); } if (null != resultSelectorLinqExpression) { // selector maps (Key, Group) -> Result // push bindings for key and group DbExpressionBinding topLevelProjectBinding = parent._commandTree.CreateExpressionBinding(topLevelProject); parent._variableNameToInputExpression.Add(topLevelProjectBinding.VariableName, topLevelProject); DbPropertyExpression keyExpression = parent._commandTree.CreatePropertyExpression( KeyColumnName, topLevelProjectBinding.Variable); DbPropertyExpression groupExpression = parent._commandTree.CreatePropertyExpression( GroupColumnName, topLevelProjectBinding.Variable); parent._bindingContext.PushBindingScope( new Binding(resultSelectorLinqExpression.Parameters[0], keyExpression), new Binding(resultSelectorLinqExpression.Parameters[1], groupExpression)); // translate selector CqtExpression resultSelector = parent.TranslateExpression( 
resultSelectorLinqExpression.Body); result = parent._commandTree.CreateProjectExpression(topLevelProjectBinding, resultSelector); // see if the selector can be optimized CqtExpression rewrittenExpression; if (parent.TryRewrite(topLevelProject, topLevelProjectBinding, resultSelector, out rewrittenExpression)) { result = rewrittenExpression; } parent._bindingContext.PopBindingScope(); } return result; } internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Fail("unreachable code"); return null; } } private sealed class GroupJoinTranslator : SequenceMethodTranslator { internal GroupJoinTranslator() : base(SequenceMethod.GroupJoin) { } internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { // o.GroupJoin(i, ok => outerKeySelector, ik => innerKeySelector, (o, i) => projection) // --> // SELECT projection(o, i) // FROM ( // SELECT o, (SELECT i FROM i WHERE o.outerKeySelector = i.innerKeySelector) as i // FROM o) // translate inputs CqtExpression outer = parent.TranslateSet(call.Arguments[0]); CqtExpression inner = parent.TranslateSet(call.Arguments[1]); // translate key selectors DbExpressionBinding outerBinding; DbExpressionBinding innerBinding; LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2); LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3); CqtExpression outerSelector = parent.TranslateLambda( outerLambda, outer, out outerBinding); CqtExpression innerSelector = parent.TranslateLambda( innerLambda, inner, out innerBinding); // create innermost SELECT i FROM i WHERE ... 
if (!TypeSemantics.IsEqualComparable(outerSelector.ResultType) || !TypeSemantics.IsEqualComparable(innerSelector.ResultType)) { throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name)); } CqtExpression nestedCollection = parent.Filter(innerBinding, parent.CreateEqualsExpression(outerSelector, innerSelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type)); // create "join" SELECT o, (nestedCollection) const string outerColumn = "o"; const string innerColumn = "i"; List > recordColumns = new List >(2); recordColumns.Add(new KeyValuePair (outerColumn, outerBinding.Variable)); recordColumns.Add(new KeyValuePair (innerColumn, nestedCollection)); CqtExpression joinProjection = parent._commandTree.CreateNewRowExpression(recordColumns); CqtExpression joinProject = parent._commandTree.CreateProjectExpression(outerBinding, joinProjection); DbExpressionBinding joinProjectBinding = parent._commandTree.CreateExpressionBinding(joinProject); // create property expressions for the outer and inner terms to bind to the parameters to the // group join selector CqtExpression outerProperty = parent._commandTree.CreatePropertyExpression(outerColumn, joinProjectBinding.Variable); CqtExpression innerProperty = parent._commandTree.CreatePropertyExpression(innerColumn, joinProjectBinding.Variable); // push the inner and the outer terms into the binding scope LambdaExpression linqSelector = parent.GetLambdaExpression(call, 4); parent._bindingContext.PushBindingScope( new Binding(linqSelector.Parameters[0], outerProperty), new Binding(linqSelector.Parameters[1], innerProperty)); // translate the selector CqtExpression selectorProject = parent.TranslateExpression(linqSelector.Body); // pop the binding scope parent._bindingContext.PopBindingScope(); // create the selector projection CqtExpression selector = parent._commandTree.CreateProjectExpression(joinProjectBinding, selectorProject); return selector; } } 
/// <summary>
/// Base translator for OrderBy/OrderByDescending: the translated key selector becomes the
/// single sort clause of a sort expression over the source binding.
/// </summary>
private abstract class OrderByTranslatorBase : OneLambdaTranslator
{
    // Sort direction passed to CreateSortClause for the key.
    private readonly bool _ascending;

    /// <param name="ascending">True to sort ascending, false to sort descending.</param>
    /// <param name="methods">Sequence methods handled by this translator.</param>
    protected OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        List<DbSortClause> keys = new List<DbSortClause>(1);
        DbSortClause sortSpec = parent._commandTree.CreateSortClause(lambda, _ascending);
        keys.Add(sortSpec);
        DbSortExpression sort = parent.Sort(sourceBinding, keys);
        return sort;
    }
}

/// <summary>Translator for OrderBy (ascending).</summary>
private sealed class OrderByTranslator : OrderByTranslatorBase
{
    internal OrderByTranslator()
        : base(true, SequenceMethod.OrderBy)
    {
    }
}

/// <summary>Translator for OrderByDescending.</summary>
private sealed class OrderByDescendingTranslator : OrderByTranslatorBase
{
    internal OrderByDescendingTranslator()
        : base(false, SequenceMethod.OrderByDescending)
    {
    }
}

// Note: because we need to "push-down" the expression binding for ThenBy, this class
// does not inherit from OneLambdaTranslator, although it is similar.
/// <summary>
/// Base translator for ThenBy/ThenByDescending: appends one more sort clause to the sort
/// expression produced by the preceding ordering call, reusing that sort's input binding.
/// </summary>
private abstract class ThenByTranslatorBase : SequenceMethodTranslator
{
    // Sort direction of the clause that is appended.
    private readonly bool _ascending;

    /// <param name="ascending">True to sort ascending, false to sort descending.</param>
    /// <param name="methods">Sequence methods handled by this translator.</param>
    protected ThenByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <exception cref="InvalidOperationException">
    /// The translated source is not a sort expression (ThenBy does not follow OrderBy).
    /// </exception>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(2 == call.Arguments.Count);
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);
        if (DbExpressionKind.Sort != source.ExpressionKind)
        {
            throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.ELinq_ThenByDoesNotFollowOrderBy);
        }

        DbSortExpression sortExpression = (DbSortExpression)source;

        // retrieve information about existing sort
        DbExpressionBinding binding = sortExpression.Input;

        // get information on new sort term
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        ParameterExpression parameter = lambdaExpression.Parameters[0];

        // push-down the binding scope information and translate the new sort key
        parent._bindingContext.PushBindingScope(new Binding(parameter, binding.Variable));
        CqtExpression lambda = parent.TranslateExpression(lambdaExpression.Body);
        parent._bindingContext.PopBindingScope();

        // create a new sort expression: the existing clauses followed by the new one
        List<DbSortClause> keys = new List<DbSortClause>(sortExpression.SortOrder);
        keys.Add(new DbSortClause(lambda, _ascending, null));
        sortExpression = parent.Sort(binding, keys);
        return sortExpression;
    }
}

/// <summary>Translator for ThenBy (ascending).</summary>
private sealed class ThenByTranslator : ThenByTranslatorBase
{
    internal ThenByTranslator()
        : base(true, SequenceMethod.ThenBy)
    {
    }
}

/// <summary>Translator for ThenByDescending.</summary>
private sealed class ThenByDescendingTranslator : ThenByTranslatorBase
{
    internal ThenByDescendingTranslator()
        : base(false, SequenceMethod.ThenByDescending)
    {
    }
}
#endregion
}
}
}

// File provided for Reference Use Only by Microsoft Corporation (c) 2007.
//----------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
// @owner [....], [....]
//---------------------------------------------------------------------
using System.Data.Common.CommandTrees;
using System.Collections.Generic;
using CqtExpression = System.Data.Common.CommandTrees.DbExpression;
using LinqExpression = System.Linq.Expressions.Expression;
using System.Diagnostics;
using System.Data.Metadata.Edm;
using System.Linq.Expressions;
using System.Reflection;
using System.Linq;
using System.Data.Entity;
using System.Data.Common;
using System.Globalization;

namespace System.Data.Objects.ELinq
{
    internal sealed partial class ExpressionConverter
    {
        /// <summary>
        /// Translates System.Linq.Expression.MethodCallExpression to System.Data.Common.CommandTrees.DbExpression
        /// </summary>
        private sealed class MethodCallTranslator : TypedTranslator<MethodCallExpression>
        {
            internal MethodCallTranslator()
                : base(ExpressionType.Call)
            {
            }

            /// <summary>
            /// Dispatches the call to, in order: a sequence-method translator, a method-specific
            /// translator, an ObjectQuery builder translator, and finally the default translator.
            /// </summary>
            protected override CqtExpression TypedTranslate(ExpressionConverter parent, MethodCallExpression linq)
            {
                // check if this is a known sequence method
                SequenceMethod sequenceMethod;
                SequenceMethodTranslator sequenceTranslator;
                if (ReflectionUtil.TryIdentifySequenceMethod(linq.Method, out sequenceMethod) &&
                    s_sequenceTranslators.TryGetValue(sequenceMethod, out sequenceTranslator))
                {
                    return sequenceTranslator.Translate(parent, linq, sequenceMethod);
                }

                // check if this is a known method
                CallTranslator callTranslator;
                if (TryGetCallTranslator(linq.Method, out callTranslator))
                {
                    return callTranslator.Translate(parent, linq);
                }

                // check if this is an ObjectQuery<> builder method
                Type declaringType = linq.Method.DeclaringType;
                if (linq.Method.IsPublic && null != declaringType && declaringType.IsGenericType &&
                    typeof(ObjectQuery<>) == declaringType.GetGenericTypeDefinition())
                {
                    ObjectQueryCallTranslator builderTranslator;
                    if (s_objectQueryTranslators.TryGetValue(linq.Method.Name, out builderTranslator))
                    {
                        return builderTranslator.Translate(parent, linq);
                    }
                }

                // fall back on the default translator
                return s_defaultTranslator.Translate(parent, linq);
            }

            #region Static members and initializers
            private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

            // initialize fall-back translator
            private static readonly CallTranslator s_defaultTranslator = new DefaultTranslator();
            private static readonly Dictionary<MethodInfo, CallTranslator> s_methodTranslators = InitializeMethodTranslators();
            private static readonly Dictionary<SequenceMethod, SequenceMethodTranslator> s_sequenceTranslators = InitializeSequenceMethodTranslators();
            private static readonly Dictionary<string, ObjectQueryCallTranslator> s_objectQueryTranslators = InitializeObjectQueryTranslators();
            private static bool s_vbMethodsInitialized;
            private static readonly object s_vbInitializerLock = new object();

            private static Dictionary<MethodInfo, CallTranslator> InitializeMethodTranslators()
            {
                // initialize translators for specific methods (e.g., Int32.op_Equality)
                Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>();
                foreach (CallTranslator translator in GetCallTranslators())
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        methodTranslators.Add(method, translator);
                    }
                }
                return methodTranslators;
            }

            private static Dictionary<SequenceMethod, SequenceMethodTranslator> InitializeSequenceMethodTranslators()
            {
                // initialize translators for sequence methods (e.g., Sequence.Select)
                Dictionary<SequenceMethod, SequenceMethodTranslator> sequenceTranslators = new Dictionary<SequenceMethod, SequenceMethodTranslator>();
                foreach (SequenceMethodTranslator translator in GetSequenceMethodTranslators())
                {
                    foreach (SequenceMethod method in translator.Methods)
                    {
                        sequenceTranslators.Add(method, translator);
                    }
                }
                return sequenceTranslators;
            }

            private static Dictionary<string, ObjectQueryCallTranslator> InitializeObjectQueryTranslators()
            {
                // initialize translators for object query methods (e.g. ObjectQuery<T>.OfType<S>(), ObjectQuery.Include(string) )
                Dictionary<string, ObjectQueryCallTranslator> objectQueryCallTranslators = new Dictionary<string, ObjectQueryCallTranslator>(StringComparer.Ordinal);
                foreach (ObjectQueryCallTranslator translator in GetObjectQueryCallTranslators())
                {
                    objectQueryCallTranslators[translator.MethodName] = translator;
                }
                return objectQueryCallTranslators;
            }

            /// <summary>
            /// Tries to get a translator for the given method info.
            /// If the given method info corresponds to a Visual Basic property,
            /// it also initializes the Visual Basic translators if they have not been initialized
            /// </summary>
            /// <param name="methodInfo">Method being translated.</param>
            /// <param name="callTranslator">Translator registered for the method, or null when none is found.</param>
            /// <returns>True when a translator was found; otherwise false.</returns>
            private static bool TryGetCallTranslator(MethodInfo methodInfo, out CallTranslator callTranslator)
            {
                // NOTE(review): this lookup happens outside s_vbInitializerLock while InitializeVBMethods
                // adds to the same dictionary under the lock -- confirm this is safe for concurrent callers.
                if (s_methodTranslators.TryGetValue(methodInfo, out callTranslator))
                {
                    return true;
                }

                // check if this is the visual basic assembly
                // (DeclaringType is null-checked the same way TypedTranslate does for the same method)
                Type declaringType = methodInfo.DeclaringType;
                if (null != declaringType &&
                    s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
                {
                    lock (s_vbInitializerLock)
                    {
                        if (!s_vbMethodsInitialized)
                        {
                            InitializeVBMethods(declaringType.Assembly);
                            s_vbMethodsInitialized = true;
                        }
                        // try again
                        return s_methodTranslators.TryGetValue(methodInfo, out callTranslator);
                    }
                }

                callTranslator = null;
                return false;
            }

            // Registers the Visual Basic call translators; called once, under s_vbInitializerLock.
            private static void InitializeVBMethods(Assembly vbAssembly)
            {
                Debug.Assert(!s_vbMethodsInitialized);
                foreach (CallTranslator translator in GetVisualBasicCallTranslators(vbAssembly))
                {
                    foreach (MethodInfo method in translator.Methods)
                    {
                        s_methodTranslators.Add(method, translator);
                    }
                }
            }

            private static IEnumerable<CallTranslator> GetVisualBasicCallTranslators(Assembly vbAssembly)
            {
                yield return new VBCanonicalFunctionDefaultTranslator(vbAssembly);
                yield return new VBCanonicalFunctionRenameTranslator(vbAssembly);
                yield return new VBDatePartTranslator(vbAssembly);
            }

            private static IEnumerable<CallTranslator> GetCallTranslators()
            {
                yield return new CanonicalFunctionDefaultTranslator();
                yield return new ContainsTranslator();
                yield return new StartsWithTranslator();
                yield return new EndsWithTranslator();
                yield return new IndexOfTranslator();
                yield return new SubstringTranslator();
                yield return new RemoveTranslator();
                yield return new InsertTranslator();
                yield return new IsNullOrEmptyTranslator();
                yield return new StringConcatTranslator();
                yield return new TrimStartTranslator();
                yield return new TrimEndTranslator();
            }

            private static IEnumerable<SequenceMethodTranslator> GetSequenceMethodTranslators()
            {
                yield return new ConcatTranslator();
                yield return new UnionTranslator();
                yield return new IntersectTranslator();
                yield return new ExceptTranslator();
                yield return new DistinctTranslator();
                yield return new WhereTranslator();
                yield return new SelectTranslator();
                yield return new OrderByTranslator();
                yield return new OrderByDescendingTranslator();
                yield return new ThenByTranslator();
                yield return new ThenByDescendingTranslator();
                yield return new SelectManyTranslator();
                yield return new AnyTranslator();
                yield return new AnyPredicateTranslator();
                yield return new AllTranslator();
                yield return new JoinTranslator();
                yield return new GroupByTranslator();
                yield return new MaxTranslator();
                yield return new MinTranslator();
                yield return new AverageTranslator();
                yield return new SumTranslator();
                yield return new CountTranslator();
                yield return new LongCountTranslator();
                yield return new CastMethodTranslator();
                yield return new GroupJoinTranslator();
                yield return new OfTypeTranslator();
                yield return new SingleTranslatorNotSupported();
                yield return new PassthroughTranslator();
                yield return new FirstTranslator();
                yield return new FirstPredicateTranslator();
                yield return new FirstOrDefaultTranslator();
                yield return new FirstOrDefaultPredicateTranslator();
                yield return new TakeTranslator();
                yield return new SkipTranslator();
            }

            private static IEnumerable<ObjectQueryCallTranslator> GetObjectQueryCallTranslators()
            {
                yield return new ObjectQueryBuilderDistinctTranslator();
                yield return new ObjectQueryBuilderExceptTranslator();
                yield return new ObjectQueryBuilderFirstTranslator();
                yield return new ObjectQueryIncludeTranslator();
                yield return new ObjectQueryBuilderIntersectTranslator();
                yield return new ObjectQueryBuilderOfTypeTranslator();
                yield return new ObjectQueryBuilderUnionTranslator();
            }
            #endregion

            #region Method translators
            /// <summary>
            /// Base class for translators keyed by the MethodInfo instances they handle.
            /// </summary>
            private abstract class CallTranslator
            {
                private readonly IEnumerable<MethodInfo> _methods;

                protected CallTranslator(params MethodInfo[] methods)
                {
                    _methods = methods;
                }

                protected CallTranslator(IEnumerable<MethodInfo> methods)
                {
                    _methods = methods;
                }

                // Methods this translator is registered for.
                internal IEnumerable<MethodInfo> Methods
                {
                    get { return _methods; }
                }

                internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

                public override string ToString()
                {
                    return GetType().Name;
                }
            }

            /// <summary>
            /// Base class for translators of ObjectQuery methods, which are looked up by method name.
            /// </summary>
            private abstract class ObjectQueryCallTranslator : CallTranslator
            {
                private readonly string _methodName;

                protected ObjectQueryCallTranslator(string methodName)
                {
                    _methodName = methodName;
                }

                internal string MethodName
                {
                    get { return _methodName; }
                }
            }

            /// <summary>
            /// ObjectQuery builder method translator that delegates to the translator registered
            /// for the equivalent sequence method.
            /// </summary>
            private abstract class ObjectQueryBuilderCallTranslator : ObjectQueryCallTranslator
            {
                private readonly SequenceMethodTranslator _translator;

                protected ObjectQueryBuilderCallTranslator(string methodName, SequenceMethod sequenceEquivalent)
                    : base(methodName)
                {
                    bool translatorFound = s_sequenceTranslators.TryGetValue(sequenceEquivalent, out _translator);
                    Debug.Assert(translatorFound, "Translator not found for " + sequenceEquivalent.ToString());
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    return _translator.Translate(parent, call);
                }
            }

            private sealed class ObjectQueryBuilderUnionTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderUnionTranslator()
                    : base("Union", SequenceMethod.Union)
                {
                }
            }

            private sealed class ObjectQueryBuilderIntersectTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderIntersectTranslator()
                    : base("Intersect", SequenceMethod.Intersect)
                {
                }
            }

            private sealed class ObjectQueryBuilderExceptTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderExceptTranslator()
                    : base("Except", SequenceMethod.Except)
                {
                }
            }

            private sealed class ObjectQueryBuilderDistinctTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderDistinctTranslator()
                    : base("Distinct", SequenceMethod.Distinct)
                {
                }
            }

            private sealed class ObjectQueryBuilderOfTypeTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderOfTypeTranslator()
                    : base("OfType", SequenceMethod.OfType)
                {
                }
            }

            private sealed class ObjectQueryBuilderFirstTranslator : ObjectQueryBuilderCallTranslator
            {
                internal ObjectQueryBuilderFirstTranslator()
                    : base("First", SequenceMethod.First)
                {
                }
            }

            /// <summary>
            /// Translator for ObjectQuery.Include(string): adds the include path to the span
            /// associated with the translated query expression.
            /// </summary>
            private sealed class ObjectQueryIncludeTranslator : ObjectQueryCallTranslator
            {
                internal ObjectQueryIncludeTranslator()
                    : base("Include")
                {
                }

                internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(string)), "Invalid Include arguments?");
                    CqtExpression queryExpression = parent.TranslateExpression(call.Object);
                    Span span;
                    if (!parent.TryGetSpan(queryExpression, out span))
                    {
                        span = null;
                    }

                    CqtExpression arg = parent.TranslateExpression(call.Arguments[0]);
                    string includePath = null;
                    if (arg.ExpressionKind == DbExpressionKind.Constant)
                    {
                        includePath = (string)((DbConstantExpression)arg).Value;
                    }
                    else
                    {
                        // The 'Include' method implementation on ELinqQueryState creates
                        // a method call expression with a string constant argument taking
                        // the value of the string argument passed to ObjectQuery.Include,
                        // and so this is the only supported pattern here.
                        throw EntityUtil.NotSupported(Entity.Strings.ELinq_UnsupportedInclude);
                    }

                    return parent.AddSpanMapping(queryExpression, Span.IncludeIn(span, includePath));
                }
            }

            /// <summary>
            /// Fall-back translator: always throws NotSupported, naming a suggested alternative
            /// method when one is registered for the unsupported method.
            /// </summary>
            private sealed class DefaultTranslator : CallTranslator
            {
                internal DefaultTranslator()
                    : base()
                {
                }

                internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
                {
                    MethodInfo suggestedMethodInfo;
                    if (TryGetAlternativeMethod(call.Method, out suggestedMethodInfo))
                    {
                        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethodSuggestedAlternative(call.Method, suggestedMethodInfo));
                    }
                    //The default error message
                    throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethod(call.Method));
                }

                #region Static Members
                private static readonly Dictionary<MethodInfo, MethodInfo> s_alternativeMethods = InitializeAlternateMethodInfos();
                private static bool s_vbMethodsInitialized;
                private static readonly object s_vbInitializerLock = new object();

                /// <summary>
                /// Tries to check whether there is an alternative method suggested instead of the given unsupported one.
                /// </summary>
                /// <param name="originalMethodInfo">The unsupported method.</param>
                /// <param name="suggestedMethodInfo">The suggested alternative, or null when none is registered.</param>
                /// <returns>True when an alternative was found; otherwise false.</returns>
                private static bool TryGetAlternativeMethod(MethodInfo originalMethodInfo, out MethodInfo suggestedMethodInfo)
                {
                    if (s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo))
                    {
                        return true;
                    }

                    // check if this is the visual basic assembly
                    // (a method without a declaring type cannot belong to it)
                    Type declaringType = originalMethodInfo.DeclaringType;
                    if (null != declaringType &&
                        s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
                    {
                        lock (s_vbInitializerLock)
                        {
                            if (!s_vbMethodsInitialized)
                            {
                                InitializeVBMethods(declaringType.Assembly);
                                s_vbMethodsInitialized = true;
                            }
                            // try again
                            return s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo);
                        }
                    }

                    suggestedMethodInfo = null;
                    return false;
                }

                /// <summary>
                /// Initializes the dictionary of alternative methods.
                /// Currently, it simply initializes an empty dictionary.
                /// </summary>
                /// <returns>A new, empty dictionary.</returns>
                private static Dictionary<MethodInfo, MethodInfo> InitializeAlternateMethodInfos()
                {
                    return new Dictionary<MethodInfo, MethodInfo>(1);
                }

                /// <summary>
                /// Populates the dictionary of alternative methods with the VB methods
                /// </summary>
                /// <param name="vbAssembly">The Visual Basic assembly to look the methods up in.</param>
                private static void InitializeVBMethods(Assembly vbAssembly)
                {
                    Debug.Assert(!s_vbMethodsInitialized);

                    //Handle { Mid(arg1, ar2), Mid(arg1, arg2, arg3) }
                    Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);

                    s_alternativeMethods.Add(
                        stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null),
                        stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int), typeof(int) }, null));
                }
                #endregion
            }

            private sealed class CanonicalFunctionDefaultTranslator : CallTranslator
            {
                internal CanonicalFunctionDefaultTranslator()
                    : base(GetMethods())
                {
                }

                private static IEnumerable<MethodInfo> GetMethods()
                {
                    //Math functions
                    yield return typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null);
                    yield return typeof(Math).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null);
                    yield return typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null);
                    yield return typeof(Math).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null);
                    yield return typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null);
                    yield return typeof(Math).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(double) }, null);

                    //Decimal functions
                    yield return typeof(Decimal).GetMethod("Floor", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null);
                    yield return typeof(Decimal).GetMethod("Ceiling", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) },
null); yield return typeof(Decimal).GetMethod("Round", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(decimal) }, null); //String functions yield return typeof(String).GetMethod("Replace", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(String), typeof(String) }, null); yield return typeof(String).GetMethod("ToLower", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); yield return typeof(String).GetMethod("ToUpper", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); yield return typeof(String).GetMethod("Trim", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { }, null); } // Default translator for method calls into canonical functions. // Translation: // MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn) // this.MethodName(arg1, arg2, .., argn) -> MethodName(this, arg1, arg2, .., argn) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { LinqExpression[] linqArguments; if (!call.Method.IsStatic) { Debug.Assert(call.Object != null, "Instance method without this"); List arguments = new List (call.Arguments.Count + 1); arguments.Add(call.Object); arguments.AddRange(call.Arguments); linqArguments = arguments.ToArray(); } else { linqArguments = call.Arguments.ToArray(); } return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, linqArguments); } } #region System.String Method Translators private sealed class ContainsTranslator : CallTranslator { internal ContainsTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Contains", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.Contains(argument) -> IndexOf(argument, object) > 0 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression indexOfExpression = 
parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); DbComparisonExpression comparisonExpression = parent._commandTree.CreateGreaterThanExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(0)); return comparisonExpression; } } private sealed class IndexOfTranslator : CallTranslator { internal IndexOfTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("IndexOf", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // IndexOf(arg1) -> IndexOf(arg1, this) - 1 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IndexOf"); DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); CqtExpression minusExpression = parent._commandTree.CreateMinusExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(1)); return minusExpression; } } private sealed class StartsWithTranslator : CallTranslator { internal StartsWithTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("StartsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.StartsWith(argument) -> IndexOf(argument, object) == 1 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object); DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression(indexOfExpression, parent._commandTree.CreateConstantExpression(1)); return comparisonExpression; } } private sealed class 
EndsWithTranslator : CallTranslator { internal EndsWithTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("EndsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null); } // Translation: // object.EndsWith(argument) -> Right(object, Length(argument)) == argument internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { DbFunctionExpression lengthExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]); DbExpression rightExpression = parent.CreateCanonicalFunction(ExpressionConverter.Right, call, parent.TranslateExpression(call.Object), lengthExpression); DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression( rightExpression, parent.TranslateExpression(call.Arguments[0])); return comparisonExpression; } } private sealed class SubstringTranslator : CallTranslator { internal SubstringTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); } // Translation: // Substring(arg1) -> Substring(this, arg1+1, Length(this) - arg1)) // Substring(arg1, arg2) -> Substring(this, arg1+1, arg2) // internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Substring"); CqtExpression target = parent.TranslateExpression(call.Object); CqtExpression fromIndex = parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent._commandTree.CreateConstantExpression(1)); CqtExpression 
length; if (call.Arguments.Count == 1) { length = parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent.TranslateExpression(call.Arguments[0])); } else { length = parent.TranslateExpression(call.Arguments[1]); } CqtExpression substringExpression = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, target, fromIndex, length); return substringExpression; } } private sealed class RemoveTranslator : CallTranslator { internal RemoveTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null); yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null); } // Translation: // Remove(arg1) -> Substring(this, 1, arg1) // Remove(arg1, arg2) -> Concat(Substring(this, 1, arg1) , Substring(this, arg1 + arg2 + 1, Length(this) - (arg1 + arg2))) // Remove(arg1, arg2) is only supported if arg2 is a non-negative integer internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Remove"); //Substring(this, 1, arg1) CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreateConstantExpression(1), parent.TranslateExpression(call.Arguments[0])); //Concat(result, Substring(this, (arg1 + arg2) +1, Length(this) - (arg1 + arg2))) if (call.Arguments.Count == 2) { //If there are two arguemtns, we only support cases when the second one translates to a non-negative constant CqtExpression translatedArgument1 = parent.TranslateExpression(call.Arguments[1]); if (!IsNonNegativeIntegerConstant(translatedArgument1)) { throw 
EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedStringRemoveCase(call.Method, call.Method.GetParameters()[1].Name)); } // Build the second substring // (arg1 + arg2) +1 CqtExpression substringStartIndex = parent._commandTree.CreatePlusExpression( parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), translatedArgument1), parent._commandTree.CreateConstantExpression(1)); // Length(this) - (arg1 + arg2) CqtExpression substringLength = parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent.TranslateExpression(call.Arguments[1]))); // Substring(this, substringStartIndex, substringLenght) CqtExpression secondSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), substringStartIndex, substringLength); // result = Concat (result, secondSubstring) result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, secondSubstring); } return result; } private static bool IsNonNegativeIntegerConstant(CqtExpression argument) { // Check whether it is a constant of type Int32 if (argument.ExpressionKind != DbExpressionKind.Constant || !TypeSemantics.IsPrimitiveType(argument.ResultType, PrimitiveTypeKind.Int32)) { return false; } // Check whether its value is non-negative DbConstantExpression constantExpression = (DbConstantExpression)argument; int value = (int)constantExpression.Value; if (value < 0) { return false; } return true; } } private sealed class InsertTranslator : CallTranslator { internal InsertTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Insert", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(string) }, null); } // Translation: // Insert(startIndex, value) 
-> Concat(Concat(Substring(this, 1, startIndex), value), Substring(this, startIndex+1, Length(this) - startIndex)) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 2, "Expecting 2 arguments for String.Insert"); //Substring(this, 1, startIndex) CqtExpression firstSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreateConstantExpression(1), parent.TranslateExpression(call.Arguments[0])); //Substring(this, startIndex+1, Length(this) - startIndex) CqtExpression secondSubstring = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, parent.TranslateExpression(call.Object), parent._commandTree.CreatePlusExpression( parent.TranslateExpression(call.Arguments[0]), parent._commandTree.CreateConstantExpression(1)), parent._commandTree.CreateMinusExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object), parent.TranslateExpression(call.Arguments[0]))); // result = Concat( Concat (firstSubstring, value), secondSubstring ) CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, firstSubstring, parent.TranslateExpression(call.Arguments[1])), secondSubstring); return result; } } private sealed class IsNullOrEmptyTranslator : CallTranslator { internal IsNullOrEmptyTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("IsNullOrEmpty", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); } // Translation: // IsNullOrEmpty(value) -> (IsNull(value)) OR Length(value) = 0 internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IsNullOrEmpty"); 
//IsNull(value) CqtExpression isNullExpression = parent._commandTree.CreateIsNullExpression( parent.TranslateExpression(call.Arguments[0])); //Length(value) = 0 CqtExpression emptyStringExpression = parent._commandTree.CreateEqualsExpression( parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]), parent._commandTree.CreateConstantExpression(0)); CqtExpression result = parent._commandTree.CreateOrExpression(isNullExpression, emptyStringExpression); return result; } } private sealed class StringConcatTranslator : CallTranslator { internal StringConcatTranslator() : base(GetMethods()) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string) }, null); yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string) }, null); yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) }, null); } // Translation: // Concat (arg1, arg2) -> Concat(arg1, arg2) // Concat (arg1, arg2, arg3) -> Concat(Concat(arg1, arg2), arg3) // Concat (arg1, arg2, arg3, arg4) -> Concat(Concat(Concat(arg1, arg2), arg3), arg4) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Assert(call.Arguments.Count >= 2 && call.Arguments.Count <= 4, "Expecting between 2 and 4 arguments for String.Concat"); CqtExpression result = parent.TranslateExpression(call.Arguments[0]); for (int argIndex = 1; argIndex < call.Arguments.Count; argIndex++) { // result = Concat(result, arg[argIndex]) result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, parent.TranslateExpression(call.Arguments[argIndex])); } return result; } } private abstract class 
TrimStartTrimEndBaseTranslator : CallTranslator { private string _canonicalFunctionName; protected TrimStartTrimEndBaseTranslator(IEnumerable methods, string canonicalFunctionName) : base(methods) { _canonicalFunctionName = canonicalFunctionName; } // Translation: // object.MethodName -> CanonicalFunctionName(object) // Supported only if the argument is an empty array. internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { if (!IsEmptyArray(call.Arguments[0])) { throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedTrimStartTrimEndCase(call.Method)); } return parent.TranslateIntoCanonicalFunction(_canonicalFunctionName, call, call.Object); } internal static bool IsEmptyArray(LinqExpression expression) { if (expression.NodeType != ExpressionType.NewArrayInit) { return false; } NewArrayExpression newArray = (NewArrayExpression)expression; if (newArray.Expressions.Count != 0) { return false; } return true; } } private sealed class TrimStartTranslator : TrimStartTrimEndBaseTranslator { internal TrimStartTranslator() : base(GetMethods(), ExpressionConverter.LTrim) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("TrimStart", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); } } private sealed class TrimEndTranslator : TrimStartTrimEndBaseTranslator { internal TrimEndTranslator() : base(GetMethods(), ExpressionConverter.RTrim) { } private static IEnumerable GetMethods() { yield return typeof(String).GetMethod("TrimEnd", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null); } } #endregion #region Visual Basic Specific Translators private sealed class VBCanonicalFunctionDefaultTranslator : CallTranslator { private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings"; private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime"; internal 
VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly) : base(GetMethods(vbAssembly)) { } private static IEnumerable GetMethods(Assembly vbAssembly) { //Strings Types Type stringsType = vbAssembly.GetType(s_stringsTypeFullName); yield return stringsType.GetMethod("Trim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("LTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("RTrim", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null); yield return stringsType.GetMethod("Left", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); yield return stringsType.GetMethod("Right", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null); //DateTimeType Type dateTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName); yield return dateTimeType.GetMethod("Year", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Month", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Day", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Hour", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Minute", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); yield return dateTimeType.GetMethod("Second", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(DateTime) }, null); } // Default translator for vb static method calls into canonical functions. 
// Translation: // MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn) internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, call.Arguments.ToArray()); } } private sealed class VBCanonicalFunctionRenameTranslator : CallTranslator { private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings"; private static readonly Dictionary s_methodNameMap = new Dictionary (4); internal VBCanonicalFunctionRenameTranslator(Assembly vbAssembly) : base(GetMethods(vbAssembly)) { } private static IEnumerable GetMethods(Assembly vbAssembly) { //Strings Types Type stringsType = vbAssembly.GetType(s_stringsTypeFullName); yield return GetMethod(stringsType, "Len", ExpressionConverter.Length, new Type[] { typeof(string) }); yield return GetMethod(stringsType, "Mid", ExpressionConverter.Substring, new Type[] { typeof(string), typeof(int), typeof(int) }); yield return GetMethod(stringsType, "UCase", ExpressionConverter.ToUpper, new Type[] { typeof(string) }); yield return GetMethod(stringsType, "LCase", ExpressionConverter.ToLower, new Type[] { typeof(string) }); } private static MethodInfo GetMethod(Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes) { MethodInfo methodInfo = declaringType.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static, null, argumentTypes, null); s_methodNameMap.Add(methodInfo, canonicalFunctionName); return methodInfo; } // Translator for static method calls into canonical functions when only the name of the canonical function // is different from the name of the method, but the argumens match. 
// Translation:
//      MethodName(arg1, arg2, .., argn) -> CanonicalFunctionName(arg1, arg2, .., argn)
// The canonical name is looked up in s_methodNameMap; the dictionary indexer throws
// if the method was never registered through GetMethod.
internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
{
    string canonicalFunctionName = s_methodNameMap[call.Method];
    return parent.TranslateIntoCanonicalFunction(canonicalFunctionName, call, call.Arguments.ToArray());
}
} // end of VBCanonicalFunctionRenameTranslator

private sealed class VBDatePartTranslator : CallTranslator
{
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";
    private const string s_DateIntervalFullName = "Microsoft.VisualBasic.DateInterval";
    private const string s_FirstDayOfWeekFullName = "Microsoft.VisualBasic.FirstDayOfWeek";
    private const string s_FirstWeekOfYearFullName = "Microsoft.VisualBasic.FirstWeekOfYear";

    // Interval names that Translate accepts; compared against the string form of the
    // constant DateInterval argument.
    private static readonly HashSet<string> s_supportedIntervals;

    internal VBDatePartTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly))
    {
    }

    static VBDatePartTranslator()
    {
        s_supportedIntervals = new HashSet<string>();
        s_supportedIntervals.Add(ExpressionConverter.Year);
        s_supportedIntervals.Add(ExpressionConverter.Month);
        s_supportedIntervals.Add(ExpressionConverter.Day);
        s_supportedIntervals.Add(ExpressionConverter.Hour);
        s_supportedIntervals.Add(ExpressionConverter.Minute);
        s_supportedIntervals.Add(ExpressionConverter.Second);
    }

    // Yields the single DatePart(DateInterval, DateTime, FirstDayOfWeek, FirstWeekOfYear)
    // overload; the three enum types are resolved by name from the supplied assembly.
    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        Type dateIntervalEnum = vbAssembly.GetType(s_DateIntervalFullName);
        Type firstDayOfWeekEnum = vbAssembly.GetType(s_FirstDayOfWeekFullName);
        Type firstWeekOfYearEnum = vbAssembly.GetType(s_FirstWeekOfYearFullName);

        Type[] datePartSignature = new Type[] { dateIntervalEnum, typeof(DateTime), firstDayOfWeekEnum, firstWeekOfYearEnum };
        yield return dateAndTimeType.GetMethod(
            "DatePart",
            BindingFlags.Public | BindingFlags.Static,
            null,
            datePartSignature,
            null);
    }

    // Translation:
    //      DatePart(DateInterval, date, arg3, arg4) -> 'DateInterval'(date)
    // Note: it is only supported for the values of DateInterval listed in s_supportedIntervals.
internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
{
    Debug.Assert(call.Arguments.Count == 4, "Expecting 4 arguments for Microsoft.VisualBasic.DateAndTime.DatePart");

    // The interval must be a compile-time constant so that it can select the function name.
    ConstantExpression intervalLinqExpression = call.Arguments[0] as ConstantExpression;
    if (intervalLinqExpression == null)
    {
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartNonConstantInterval(call.Method, call.Method.GetParameters()[0].Name));
    }

    string intervalValue = intervalLinqExpression.Value.ToString();
    if (!s_supportedIntervals.Contains(intervalValue))
    {
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartInvalidInterval(call.Method, call.Method.GetParameters()[0].Name, intervalValue));
    }

    // Only the date argument is forwarded; arguments 3 and 4 are ignored.
    CqtExpression result = parent.TranslateIntoCanonicalFunction(intervalValue, call, call.Arguments[1]);
    return result;
}
} // end of VBDatePartTranslator
#endregion

#endregion

#region Sequence method translators
// Base class for translators of LINQ sequence operators. Each translator declares the
// SequenceMethod values it handles and turns a matching call into a command tree expression.
private abstract class SequenceMethodTranslator
{
    private readonly IEnumerable<SequenceMethod> _methods;

    protected SequenceMethodTranslator(params SequenceMethod[] methods)
    {
        _methods = methods;
    }

    // The sequence methods this translator was constructed with.
    internal IEnumerable<SequenceMethod> Methods
    {
        get { return _methods; }
    }

    // Overload that also receives the identified sequence method; by default it is
    // ignored and the call is forwarded to the two-argument overload.
    internal virtual CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
    {
        return Translate(parent, call);
    }

    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    public override string ToString()
    {
        return GetType().Name;
    }
}

// Shared logic for Skip and Take: translates the count argument and applies the
// paging operator underneath any projection at the root of the operand.
private abstract class PagingTranslator : UnarySequenceMethodTranslator
{
    protected PagingTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // translate count expression
        Debug.Assert(call.Arguments.Count == 2, "Skip and Take must have 2 arguments");
        LinqExpression linqCount = call.Arguments[1];
        CqtExpression count = parent.TranslateExpression(linqCount);

        // remove projections at the root of the expression and then reapply after applying the paging operator
        DbProjectExpression projection = null;
        if (operand.ExpressionKind == DbExpressionKind.Project)
        {
            projection = (DbProjectExpression)operand;
            operand = projection.Input.Expression;
        }

        // translate paging expression
        DbExpression result = TranslatePagingOperator(parent, operand, count);

        // reapply project as necessary (the existing projection is re-pointed at the paged input)
        if (null != projection)
        {
            projection.Input.Expression = result;
            result = projection;
        }

        return result;
    }

    protected abstract CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count);
}

// Take(n) -> Limit(operand, n)
private sealed class TakeTranslator : PagingTranslator
{
    internal TakeTranslator()
        : base(SequenceMethod.Take)
    {
    }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        return parent.Limit(operand, count);
    }
}

// Skip(n) over a sorted input -> DbSkipExpression reusing the sort order of that input.
private sealed class SkipTranslator : PagingTranslator
{
    internal SkipTranslator()
        : base(SequenceMethod.Skip)
    {
    }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        // skip requires a sorted input
        if (operand.ExpressionKind != DbExpressionKind.Sort)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_SkipWithoutOrder);
        }

        DbSortExpression sortedOperand = (DbSortExpression)operand;
        Span sortedSpan = null;
        bool hadSpan = parent.TryGetSpan(sortedOperand, out sortedSpan);

        // generate a skip statement with the sort order of the original sorted input.
        DbSkipExpression skip = parent.Skip(sortedOperand.Input, sortedOperand.SortOrder, count);

        // If the original DbSortExpression had Span information, then this is applied
        // to the newly created DbSkipExpression before returning.
        if (hadSpan)
        {
            parent.AddSpanMapping(skip, sortedSpan);
        }

        return skip;
    }
}

// Single/SingleOrDefault (with or without predicate) always fail with NotSupported.
private sealed class SingleTranslatorNotSupported : SequenceMethodTranslator
{
    internal SingleTranslatorNotSupported()
        : base(SequenceMethod.Single, SequenceMethod.SinglePredicate,
               SequenceMethod.SingleOrDefault, SequenceMethod.SingleOrDefaultPredicate)
    {
    }

    // CqtExpression is the file-level alias for System.Data.Common.CommandTrees.DbExpression.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedSingle);
    }
}

// Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector) ->
//      Project(InnerJoin(outer, inner, outerKey = innerKey), resultSelector)
private sealed class JoinTranslator : SequenceMethodTranslator
{
    internal JoinTranslator()
        : base(SequenceMethod.Join)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(5 == call.Arguments.Count);

        // get expressions describing inputs to the join
        CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
        CqtExpression inner = parent.TranslateSet(call.Arguments[1]);

        // get expressions describing key selectors
        LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
        LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);

        // translate key selectors
        DbExpressionBinding outerBinding;
        DbExpressionBinding innerBinding;
        CqtExpression outerKeySelector = parent.TranslateLambda(outerLambda, outer, out outerBinding);
        CqtExpression innerKeySelector = parent.TranslateLambda(innerLambda, inner, out innerBinding);

        // construct join expression
        if (!TypeSemantics.IsEqualComparable(outerKeySelector.ResultType) ||
            !TypeSemantics.IsEqualComparable(innerKeySelector.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }

        DbJoinExpression join = parent._commandTree.CreateInnerJoinExpression(
            outerBinding,
            innerBinding,
            parent.CreateEqualsExpression(outerKeySelector, innerKeySelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type));
        DbExpressionBinding joinBinding = parent._commandTree.CreateExpressionBinding(join);

        // get selector expression
        LambdaExpression selectorLambda = parent.GetLambdaExpression(call, 4);

        // create property expressions for the inner and outer
        DbPropertyExpression joinOuter = parent._commandTree.CreatePropertyExpression(
            outerBinding.VariableName, joinBinding.Variable);
        DbPropertyExpression joinInner = parent._commandTree.CreatePropertyExpression(
            innerBinding.VariableName, joinBinding.Variable);

        // push outer and inner join parts into the binding scope (the order
        // is irrelevant because the binding context matches based on parameter
        // reference rather than ordinal)
        parent._bindingContext.PushBindingScope(
            new Binding(selectorLambda.Parameters[0], joinOuter),
            new Binding(selectorLambda.Parameters[1], joinInner));

        // translate join selector
        CqtExpression selector = parent.TranslateExpression(selectorLambda.Body);

        // pop binding scope
        parent._bindingContext.PopBindingScope();

        return parent._commandTree.CreateProjectExpression(joinBinding, selector);
    }
}

// Base class for operators over two sets; supports both the instance-method form
// (left is call.Object) and the static extension-method form (left is the first argument).
private abstract class BinarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected BinarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        if (null != call.Object)
        {
            // instance method
            Debug.Assert(1 == call.Arguments.Count);
            CqtExpression left = parent.TranslateSet(call.Object);
            CqtExpression right = parent.TranslateSet(call.Arguments[0]);
            return TranslateBinary(parent, left, right);
        }
        else
        {
            // static extension method
            Debug.Assert(2 == call.Arguments.Count);
            CqtExpression left = parent.TranslateSet(call.Arguments[0]);
            CqtExpression right = parent.TranslateSet(call.Arguments[1]);
            return TranslateBinary(parent, left, right);
        }
    }

    protected abstract CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right);
}

// Concat(left, right) -> UnionAll(left, right)
private class ConcatTranslator : BinarySequenceMethodTranslator
{
    internal ConcatTranslator()
        : base(SequenceMethod.Concat)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.UnionAll(left, right);
    }
}

// Union(left, right) -> Distinct(UnionAll(left, right))
private sealed class UnionTranslator : BinarySequenceMethodTranslator
{
    internal UnionTranslator()
        : base(SequenceMethod.Union)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.Distinct(parent.UnionAll(left, right));
    }
}

// Intersect(left, right) -> Intersect(left, right)
private sealed class IntersectTranslator : BinarySequenceMethodTranslator
{
    internal IntersectTranslator()
        : base(SequenceMethod.Intersect)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.Intersect(left, right);
    }
}

// Except(left, right) -> Except(left, right)
private sealed class ExceptTranslator : BinarySequenceMethodTranslator
{
    internal ExceptTranslator()
        : base(SequenceMethod.Except)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        return parent.Except(left, right);
    }
}

// Base class for aggregate operators (Max, Min, Average, Sum, Count, ...). The operand
// (optionally filtered or projected by the lambda argument) is fed to a key-less
// group-by containing a single function aggregate, and the result is returned as an element.
private abstract class AggregateTranslator : SequenceMethodTranslator
{
    private readonly string _functionName;
    private readonly bool _takesPredicate;

    // functionName:   name of the canonical aggregate function to bind to.
    // takesPredicate: true when the optional lambda argument is a filter, false when it is a selector.
    protected AggregateTranslator(string functionName, bool takesPredicate, params SequenceMethod[] methods)
        : base(methods)
    {
        _takesPredicate = takesPredicate;
        _functionName = functionName;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool isUnary = 1 == call.Arguments.Count;
        Debug.Assert(isUnary || 2 == call.Arguments.Count);

        CqtExpression operand = parent.TranslateSet(call.Arguments[0]);

        // Facet information for the return type cannot help in determining the appropriate function overload
        // since no constraints on the return value are known.
        TypeUsage returnType = parent.GetValueLayerType(call.Type);

        LambdaExpression lambda = null;

        //We save the original operand for the optimized translation
        CqtExpression originalOperand = operand;

        if (!isUnary)
        {
            lambda = parent.GetLambdaExpression(call, 1);
            DbExpressionBinding sourceBinding;
            CqtExpression cqtLambda = parent.TranslateLambda(lambda, operand, out sourceBinding);

            if (_takesPredicate)
            {
                // treat the lambda as a filter
                operand = parent.Filter(sourceBinding, cqtLambda);
            }
            else
            {
                // treat the lambda as a selector
                operand = parent._commandTree.CreateProjectExpression(sourceBinding, cqtLambda);
            }
        }

        operand = WrapCollectionOperand(parent, operand, returnType);
        DbGroupExpressionBinding operandBinding = parent._commandTree.CreateGroupExpressionBinding(operand);
        EdmFunction function = FindFunction(parent, call, returnType);

        // create aggregate
        // NOTE(review): the type arguments of the two lists below were lost in the extracted
        // source and are reconstructed to fit CreateGroupByExpression - confirm against the original.
        List<KeyValuePair<string, DbExpression>> keys = new List<KeyValuePair<string, DbExpression>>(0); // no key
        List<KeyValuePair<string, DbAggregate>> aggregates = new List<KeyValuePair<string, DbAggregate>>(1);
        const string aggregateName = "Aggregate";
        aggregates.Add(new KeyValuePair<string, DbAggregate>(
            aggregateName, // name is arbitrary (there is only one in this context)
            parent._commandTree.CreateFunctionAggregate(function, operandBinding.GroupVariable)));
        DbGroupByExpression aggregate = parent._commandTree.CreateGroupByExpression(
            operandBinding, keys, aggregates);
        DbExpressionBinding aggregateBinding = parent._commandTree.CreateExpressionBinding(aggregate);

        // project result
        DbPropertyExpression property = parent._commandTree.CreatePropertyExpression(
            aggregateName, aggregateBinding.Variable);
        DbProjectExpression projection = parent._commandTree.CreateProjectExpression(
            aggregateBinding, parent.AlignTypes(property, call.Type));

        // return a single element to represent the projection
        DbElementExpression element = parent._commandTree.CreateElementExpression(projection);

        // Try to create and log an optimized translation
        TryCreateOptimizedTranslation(parent, lambda, originalOperand, function, element);

        return element;
    }

    // If the function is over a group by, it tries to
// incorporate the aggregate function into the group by.
// If it does, it gives the aggregate an alias and logs it in
// parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap, keyed by the default translation.
private void TryCreateOptimizedTranslation(ExpressionConverter parent, LambdaExpression lambda, CqtExpression operand, EdmFunction function, CqtExpression originalTranslation)
{
    //Aggregates that take predicates as arguments cannot be incorporated into a group by
    if (_takesPredicate && (lambda != null))
    {
        return;
    }

    //Check whether the operand is a property over an output of grouping
    if (operand.ExpressionKind != DbExpressionKind.Property)
    {
        return;
    }

    DbPropertyExpression propertyExpression = (DbPropertyExpression)operand;
    if (propertyExpression.Instance.ExpressionKind != DbExpressionKind.VariableReference)
    {
        return;
    }

    DbVariableReferenceExpression inputVarRef = (DbVariableReferenceExpression)propertyExpression.Instance;

    //If the input corresponding to the var ref has an optimized translation, generate an alternate translation for this as well.
    CqtExpression input;
    if (!parent._variableNameToInputExpression.TryGetValue(inputVarRef.VariableName, out input))
    {
        return;
    }

    DbGroupByTemplate optimizedTranslationOfInput;
    if (!parent._groupByDefaultToOptimizedTranslationMap.TryGetValue(input, out optimizedTranslationOfInput))
    {
        return;
    }

    Debug.Assert(TypeSemantics.IsCollectionType(function.Parameters[0].TypeUsage), "Aggregates should always have collection arguments");
    TypeUsage elementType = TypeHelpers.GetElementTypeUsage(function.Parameters[0].TypeUsage);

    // The aggregate can be added to the list of aggregates.
    CqtExpression aggregateArgument = optimizedTranslationOfInput.Input.GroupVariable;
    if (lambda != null)
    {
        aggregateArgument = parent.TranslateLambda(lambda, aggregateArgument);
    }

    // The alias is made unique by suffixing the number of aggregates already present.
    string optimizedTranslationAlias = String.Format(CultureInfo.InvariantCulture, "Aggregate{0}", optimizedTranslationOfInput.Aggregates.Count);

    // NOTE(review): KeyValuePair type arguments were lost in the extracted source; <string, DbAggregate>
    // is reconstructed from the arguments passed here - confirm against DbGroupByTemplate.Aggregates.
    optimizedTranslationOfInput.Aggregates.Add(new KeyValuePair<string, DbAggregate>(
        optimizedTranslationAlias,
        parent._commandTree.CreateFunctionAggregate(function, WrapNonCollectionOperand(parent, aggregateArgument, elementType))));

    //log the alias
    parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.Add(
        originalTranslation,
        new KeyValuePair<DbGroupByTemplate, string>(optimizedTranslationOfInput, optimizedTranslationAlias));
}

// If necessary, wraps the operand to ensure the appropriate aggregate overload is called
protected virtual CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
{
    // check if the operand needs to be wrapped to ensure the correct function overload is called
    if (!TypeUsageEquals(returnType, ((CollectionType)operand.ResultType.EdmType).TypeUsage))
    {
        // project a cast of every element to the return type
        DbExpressionBinding operandCastBinding = parent._commandTree.CreateExpressionBinding(operand);
        DbProjectExpression operandCastProjection = parent._commandTree.CreateProjectExpression(
            operandCastBinding,
            parent._commandTree.CreateCastExpression(operandCastBinding.Variable, returnType));
        operand = operandCastProjection;
    }

    return operand;
}

// If necessary, wraps the operand to ensure the appropriate aggregate overload is called
protected virtual CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
{
    if (!TypeUsageEquals(returnType, operand.ResultType))
    {
        operand = parent._commandTree.CreateCastExpression(operand, returnType);
    }

    return operand;
}

// Finds the best function overload given the expected return type
protected virtual EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call, TypeUsage argumentType)
{
    List<TypeUsage> argTypes = new List<TypeUsage>(1);

    // In general, we use the return type as the parameter type to align LINQ semantics
    // with SQL semantics, and avoid apparent loss of precision for some LINQ aggregate operators.
    // (e.g., AVG(1, 2) = 2.0, AVG((double)1, (double)2)) = 1.5)
    argTypes.Add(argumentType);

    return parent.FindCanonicalFunction(_functionName, argTypes, true /* isGroupAggregateFunction */, call);
}
} // end of AggregateTranslator

// Max overloads -> canonical "MAX" aggregate (the lambda, when present, is a selector).
private sealed class MaxTranslator : AggregateTranslator
{
    internal MaxTranslator()
        : base("MAX", false,
            SequenceMethod.Max,
            SequenceMethod.MaxSelector,
            SequenceMethod.MaxInt,
            SequenceMethod.MaxIntSelector,
            SequenceMethod.MaxDecimal,
            SequenceMethod.MaxDecimalSelector,
            SequenceMethod.MaxDouble,
            SequenceMethod.MaxDoubleSelector,
            SequenceMethod.MaxLong,
            SequenceMethod.MaxLongSelector,
            SequenceMethod.MaxSingle,
            SequenceMethod.MaxSingleSelector,
            SequenceMethod.MaxNullableDecimal,
            SequenceMethod.MaxNullableDecimalSelector,
            SequenceMethod.MaxNullableDouble,
            SequenceMethod.MaxNullableDoubleSelector,
            SequenceMethod.MaxNullableInt,
            SequenceMethod.MaxNullableIntSelector,
            SequenceMethod.MaxNullableLong,
            SequenceMethod.MaxNullableLongSelector,
            SequenceMethod.MaxNullableSingle,
            SequenceMethod.MaxNullableSingleSelector)
    {
    }
}

// Min overloads -> canonical "MIN" aggregate (the lambda, when present, is a selector).
private sealed class MinTranslator : AggregateTranslator
{
    internal MinTranslator()
        : base("MIN", false,
            SequenceMethod.Min,
            SequenceMethod.MinSelector,
            SequenceMethod.MinDecimal,
            SequenceMethod.MinDecimalSelector,
            SequenceMethod.MinDouble,
            SequenceMethod.MinDoubleSelector,
            SequenceMethod.MinInt,
            SequenceMethod.MinIntSelector,
            SequenceMethod.MinLong,
            SequenceMethod.MinLongSelector,
            SequenceMethod.MinNullableDecimal,
            SequenceMethod.MinSingle,
            SequenceMethod.MinSingleSelector,
            SequenceMethod.MinNullableDecimalSelector,
            SequenceMethod.MinNullableDouble,
            SequenceMethod.MinNullableDoubleSelector,
            SequenceMethod.MinNullableInt,
            SequenceMethod.MinNullableIntSelector,
            SequenceMethod.MinNullableLong,
            SequenceMethod.MinNullableLongSelector,
            SequenceMethod.MinNullableSingle,
            SequenceMethod.MinNullableSingleSelector)
    {
    }
}

// Average overloads -> canonical "AVG" aggregate (the lambda, when present, is a selector).
private sealed class AverageTranslator : AggregateTranslator
{
    internal AverageTranslator()
        : base("AVG", false,
            SequenceMethod.AverageDecimal,
            SequenceMethod.AverageDecimalSelector,
            SequenceMethod.AverageDouble,
            SequenceMethod.AverageDoubleSelector,
            SequenceMethod.AverageInt,
            SequenceMethod.AverageIntSelector,
            SequenceMethod.AverageLong,
            SequenceMethod.AverageLongSelector,
            SequenceMethod.AverageSingle,
            SequenceMethod.AverageSingleSelector,
            SequenceMethod.AverageNullableDecimal,
            SequenceMethod.AverageNullableDecimalSelector,
            SequenceMethod.AverageNullableDouble,
            SequenceMethod.AverageNullableDoubleSelector,
            SequenceMethod.AverageNullableInt,
            SequenceMethod.AverageNullableIntSelector,
            SequenceMethod.AverageNullableLong,
            SequenceMethod.AverageNullableLongSelector,
            SequenceMethod.AverageNullableSingle,
            SequenceMethod.AverageNullableSingleSelector)
    {
    }
}

// Sum overloads -> canonical "SUM" aggregate (the lambda, when present, is a selector).
private sealed class SumTranslator : AggregateTranslator
{
    internal SumTranslator()
        : base("SUM", false,
            SequenceMethod.SumDecimal,
            SequenceMethod.SumDecimalSelector,
            SequenceMethod.SumDouble,
            SequenceMethod.SumDoubleSelector,
            SequenceMethod.SumInt,
            SequenceMethod.SumIntSelector,
            SequenceMethod.SumLong,
            SequenceMethod.SumLongSelector,
            SequenceMethod.SumSingle,
            SequenceMethod.SumSingleSelector,
            SequenceMethod.SumNullableDecimal,
            SequenceMethod.SumNullableDecimalSelector,
            SequenceMethod.SumNullableDouble,
            SequenceMethod.SumNullableDoubleSelector,
            SequenceMethod.SumNullableInt,
            SequenceMethod.SumNullableIntSelector,
            SequenceMethod.SumNullableLong,
            SequenceMethod.SumNullableLongSelector,
            SequenceMethod.SumNullableSingle,
            SequenceMethod.SumNullableSingleSelector)
    {
    }
}

// Shared logic for Count and LongCount: the lambda (when present) is a predicate, and the
// aggregate is always computed over a projected constant rather than over the elements.
private abstract class CountTranslatorBase : AggregateTranslator
{
    protected CountTranslatorBase(string functionName, params SequenceMethod[] methods)
        : base(functionName, true, methods)
    {
    }

    protected override CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // always count a constant value
        DbExpressionBinding operandBinding = parent._commandTree.CreateExpressionBinding(operand);
        DbProjectExpression constantProject = parent._commandTree.CreateProjectExpression(
            operandBinding, parent._commandTree.CreateTrueExpression());
        return constantProject;
    }

    protected override CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // always count a constant value
        DbExpression constantExpression = parent._commandTree.CreateConstantExpression(1);
        if (!TypeUsageEquals(constantExpression.ResultType, returnType))
        {
            constantExpression = parent._commandTree.CreateCastExpression(constantExpression, returnType);
        }

        return constantExpression;
    }

    protected override EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call, TypeUsage argumentType)
    {
        // For most ELinq aggregates, the argument type is the return type. For "count", the
        // argument type is always Boolean, since we project a constant Boolean value in WrapCollectionOperand.
        TypeUsage booleanTypeUsage = parent._commandTree.CreateTrueExpression().ResultType;
        return base.FindFunction(parent, call, booleanTypeUsage);
    }
}

// Count / Count(predicate) -> canonical "COUNT" aggregate
private sealed class CountTranslator : CountTranslatorBase
{
    internal CountTranslator()
        : base("COUNT", SequenceMethod.Count, SequenceMethod.CountPredicate)
    {
    }
}

// LongCount / LongCount(predicate) -> canonical "BIGCOUNT" aggregate
private sealed class LongCountTranslator : CountTranslatorBase
{
    internal LongCountTranslator()
        : base("BIGCOUNT", SequenceMethod.LongCount, SequenceMethod.LongCountPredicate)
    {
    }
}

// Base class for operators over a single set; the set is call.Object for instance
// methods and the first argument for static extension methods.
private abstract class UnarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected UnarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        if (null != call.Object)
        {
            // instance method
            Debug.Assert(0 <= call.Arguments.Count);
            CqtExpression operand = parent.TranslateSet(call.Object);
            return TranslateUnary(parent, operand, call);
        }
        else
        {
            // static extension method
            Debug.Assert(1 <= call.Arguments.Count);
            CqtExpression operand = parent.TranslateSet(call.Arguments[0]);
            return TranslateUnary(parent, operand, call);
        }
    }

    protected abstract CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call);
}

// AsQueryable / AsEnumerable: returns the operand unchanged when it is a collection.
private sealed class PassthroughTranslator : UnarySequenceMethodTranslator
{
    internal PassthroughTranslator()
        : base(SequenceMethod.AsQueryableGeneric, SequenceMethod.AsQueryable, SequenceMethod.AsEnumerable)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // make sure the operand has collection type to avoid treating (for instance) String as a
        // sub-query
        if (TypeSemantics.IsCollectionType(operand.ResultType))
        {
            return operand;
        }
        else
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedPassthrough(
                call.Method.Name, operand.ResultType.EdmType.Name));
        }
    }
}

// OfType<T>() -> OfType(operand, model type of T)
private sealed class OfTypeTranslator : UnarySequenceMethodTranslator
{
    internal OfTypeTranslator()
        : base(SequenceMethod.OfType)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        Type clrType = call.Method.GetGenericArguments()[0];
        TypeUsage modelType;

        // If the model type does not exist in the perspective or is not either an EntityType
        // or a ComplexType, fail - OfType<T>() is not a valid operation on scalars,
        // enumerations, collections, etc.
        if (!parent.TryGetValueLayerType(clrType, out modelType) ||
            !(TypeSemantics.IsEntityType(modelType) || TypeSemantics.IsComplexType(modelType)))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_InvalidOfTypeResult(DescribeClrType(clrType)));
        }

        // Create an of type expression to filter the original query to include
        // only those results that are of the specified type.
        CqtExpression ofTypeExpression = parent.OfType(operand, modelType);
        return ofTypeExpression;
    }
}

// Distinct() -> Distinct(operand)
private sealed class DistinctTranslator : UnarySequenceMethodTranslator
{
    internal DistinctTranslator()
        : base(SequenceMethod.Distinct)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        return parent.Distinct(operand);
    }
}

// Any() -> Not(IsEmpty(operand))
private sealed class AnyTranslator : UnarySequenceMethodTranslator
{
    internal AnyTranslator()
        : base(SequenceMethod.Any)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // "Any" is equivalent to "exists".
        return parent._commandTree.CreateNotExpression(
            parent._commandTree.CreateIsEmptyExpression(operand));
    }
}

// Base class for operators of the form Method(source, lambda): translates the source and
// the lambda (bound to the source) and lets the subclass combine them.
private abstract class OneLambdaTranslator : SequenceMethodTranslator
{
    internal OneLambdaTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        return Translate(parent, call, out source, out sourceBinding, out lambda);
    }

    // Helper method for translation; also hands back the translated source, its binding
    // and the translated lambda body for subclasses that need them.
    protected CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call,
        out CqtExpression source, out DbExpressionBinding sourceBinding, out CqtExpression lambda)
    {
        Debug.Assert(2 <= call.Arguments.Count);

        // translate source
        source = parent.TranslateExpression(call.Arguments[0]);

        // translate lambda expression
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        lambda = parent.TranslateLambda(lambdaExpression, source, out sourceBinding);

        return TranslateOneLambda(parent, sourceBinding, lambda);
    }

    protected abstract CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda);
}

// Any(predicate) -> Any(sourceBinding, predicate)
private sealed class AnyPredicateTranslator : OneLambdaTranslator
{
    internal AnyPredicateTranslator()
        : base(SequenceMethod.AnyPredicate)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent._commandTree.CreateAnyExpression(sourceBinding, lambda);
    }
}

// All(predicate) -> All(sourceBinding, predicate)
private sealed class AllTranslator : OneLambdaTranslator
{
    internal AllTranslator()
        : base(SequenceMethod.All)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent._commandTree.CreateAllExpression(sourceBinding, lambda);
    }
}

// Where(predicate) -> Filter(sourceBinding, predicate)
private sealed class WhereTranslator : OneLambdaTranslator
{
    internal WhereTranslator()
        : base(SequenceMethod.Where)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent.Filter(sourceBinding, lambda);
    }
}

// Select(selector) -> Project(sourceBinding, selector), or the rewritten expression
// when parent.TryRewrite succeeds.
private sealed class SelectTranslator : OneLambdaTranslator
{
    internal SelectTranslator()
        : base(SequenceMethod.Select)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        CqtExpression result = Translate(parent, call, out source, out sourceBinding, out lambda);

        // If the select is over a GroupBy, check whether an optimized translation can be produced.
        // The default translation is to 'manually' do group by, the optimized is translating into a
        // DbGroupByExpression
        CqtExpression rewrittenExpression;
        if (parent.TryRewrite(source, sourceBinding, lambda, out rewrittenExpression))
        {
            result = rewrittenExpression;
        }

        return result;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent.Project(sourceBinding, lambda);
    }
}

// Shared logic for First() and FirstOrDefault() without a predicate.
private abstract class FirstTranslatorBase : UnarySequenceMethodTranslator
{
    // True for the ...OrDefault variant.
    private readonly bool _orDefault;

    protected FirstTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // Apply a Limit(1) expression to restrict the result to a single element
        // (this returns an empty set if the input set is initially empty).
        CqtExpression result = parent.Limit(operand, parent._commandTree.CreateConstantExpression(1));

        // If this First() or FirstOrDefault() operation is the root of the query,
        // then the evaluation is performed in the client over the resulting set,
        // to provide the same semantics as Linq to Objects. Otherwise, an Element
        // expression is applied to retrieve the single element (or null, if empty)
        // from the output set.
        if (!parent.IsQueryRoot(call))
        {
            if (_orDefault)
            {
                result = parent._commandTree.CreateElementExpression(result);
                result = AddDefaultCase(parent, result, call.Type);
            }
            else
            {
                // First is not allowed anywhere other than the root of the query,
                // since First() operations logically executed in the store will not
                // throw an exception if the input set is empty (typically they will
                // simply produce a null result).
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
            }
        }

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span inputSpan = null;
        if (parent.TryGetSpan(operand, out inputSpan))
        {
            parent.AddSpanMapping(result, inputSpan);
        }

        return result;
    }

    // Wraps 'element' in CASE WHEN element IS NULL THEN default(elementType) ELSE element
    // when the CLR default of elementType is not null; otherwise returns 'element' unchanged.
    internal static CqtExpression AddDefaultCase(ExpressionConverter parent, CqtExpression element, Type elementType)
    {
        // Retrieve default value.
        object defaultValue = TypeSystem.GetDefaultValue(elementType);
        if (null == defaultValue)
        {
            // Already null, which is the implicit default for DbElementExpression
            return element;
        }

        // Otherwise, use the default value for the type
        List<CqtExpression> whenExpressions = new List<CqtExpression>(1);
        whenExpressions.Add(parent.CreateIsNullExpression(element, elementType));

        List<CqtExpression> thenExpressions = new List<CqtExpression>(1);
        thenExpressions.Add(parent._commandTree.CreateConstantExpression(defaultValue));

        DbCaseExpression caseExpression = parent._commandTree.CreateCaseExpression(
            whenExpressions, thenExpressions, element);
        return caseExpression;
    }
}

private sealed class FirstTranslator : FirstTranslatorBase
{
    internal FirstTranslator()
        : base(false, SequenceMethod.First)
    {
    }
}

private sealed class FirstOrDefaultTranslator : FirstTranslatorBase
{
    internal FirstOrDefaultTranslator()
        : base(true, SequenceMethod.FirstOrDefault)
    {
    }
}

// Shared logic for First(predicate) and FirstOrDefault(predicate).
private abstract class FirstPredicateTranslatorBase : OneLambdaTranslator
{
    // True for the ...OrDefault variant.
    private readonly bool _orDefault;

    protected FirstPredicateTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Convert the input set and the predicate into a filter expression
        CqtExpression input = base.Translate(parent, call);

        // If this First/FirstOrDefault is the root of the query,
        // then the actual result will be produced by
        // calling First() or FirstOrDefault() on the filtered input set,
        // which is limited to at most one element by applying a Limit(1) expression.
        if (parent.IsQueryRoot(call))
        {
            // Calling ExpressionConverter.Limit propagates the Span.
            return parent.Limit(input, parent._commandTree.CreateConstantExpression(1));
        }
        else
        {
            if (_orDefault)
            {
                CqtExpression element = parent._commandTree.CreateElementExpression(input);
                element = FirstTranslatorBase.AddDefaultCase(parent, element, call.Type);

                // Span is preserved over First/FirstOrDefault with or without a predicate
                Span inputSpan = null;
                if (parent.TryGetSpan(input, out inputSpan))
                {
                    parent.AddSpanMapping(element, inputSpan);
                }

                return element;
            }
            else
            {
                // First (with or without a predicate) is not allowed anywhere other
                // than the root of the query, since First(predicate) operations
                // logically executed in the store will not throw an exception if the
                // input set is empty (typically they will simply produce a null result).
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
            }
        }
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        return parent.Filter(sourceBinding, lambda);
    }
}

private sealed class FirstPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstPredicateTranslator()
        : base(false, SequenceMethod.FirstPredicate)
    {
    }
}

private sealed class FirstOrDefaultPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstOrDefaultPredicateTranslator()
        : base(true, SequenceMethod.FirstOrDefaultPredicate)
    {
    }
}

// SelectMany with or without a result selector, implemented through a cross apply.
private sealed class SelectManyTranslator : OneLambdaTranslator
{
    internal SelectManyTranslator()
        : base(SequenceMethod.SelectMany, SequenceMethod.SelectManyResultSelector)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool hasSelector = 3 == call.Arguments.Count;

        CqtExpression crossApply = base.Translate(parent, call);

        // perform a cross apply to implement the core logic for SelectMany (this translates the collection selector):
        // SelectMany(i, Func<i, IEnumerable<j>> collectionSelector) =>
        //      i CROSS APPLY collectionSelector(i)
        // The cross-apply yields a collection from which we yield either the right hand side (when
        // no explicit resultSelector is given) or over which we apply the resultSelector Lambda expression.
        DbExpressionBinding crossApplyBinding = parent._commandTree.CreateExpressionBinding(crossApply);
        RowType crossApplyRowType = (RowType)(crossApplyBinding.Variable.ResultType.EdmType);
        CqtExpression projectRight = parent._commandTree.CreatePropertyExpression(crossApplyRowType.Properties[1], crossApplyBinding.Variable);

        CqtExpression resultProjection;
        if (hasSelector)
        {
            CqtExpression projectLeft = parent._commandTree.CreatePropertyExpression(crossApplyRowType.Properties[0], crossApplyBinding.Variable);

            LambdaExpression resultSelector = parent.GetLambdaExpression(call, 2);

            // add the left and right projection terms to the binding context
            parent._bindingContext.PushBindingScope(
                new Binding(resultSelector.Parameters[0], projectLeft),
                new Binding(resultSelector.Parameters[1], projectRight));

            // translate the result selector
            resultProjection = parent.TranslateSet(resultSelector.Body);

            // pop binding context
            parent._bindingContext.PopBindingScope();
        }
        else
        {
            // project out the right hand side of the apply
            resultProjection = projectRight;
        }

        // wrap result projection in project expression
        return parent._commandTree.CreateProjectExpression(crossApplyBinding, resultProjection);
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        // elements of the inner selector should be used
        lambda = parent.NormalizeSetSource(lambda);
        DbExpressionBinding applyBinding = parent._commandTree.CreateExpressionBinding(lambda);
        DbApplyExpression crossApply = parent._commandTree.CreateCrossApplyExpression(sourceBinding, applyBinding);
        return crossApply;
    }
}

// NOTE: this class continues past the end of this block; the fragment below is
// reproduced unchanged and its final statement is completed by the following line.
private sealed class CastMethodTranslator : SequenceMethodTranslator
{
    internal CastMethodTranslator()
        : base(SequenceMethod.Cast)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Translate source
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);

        // Figure out the type to cast to
        Type
toClrType = TypeSystem.GetElementType(call.Type); Type fromClrType = TypeSystem.GetElementType(call.Arguments[0].Type); // Get binding to the elements of the input source DbExpressionBinding binding = parent._commandTree.CreateExpressionBinding(source); CqtExpression cast = parent.CreateCastExpression(binding.Variable, toClrType, fromClrType); return parent._commandTree.CreateProjectExpression(binding, cast); } } private sealed class GroupByTranslator : SequenceMethodTranslator { internal GroupByTranslator() : base(SequenceMethod.GroupBy, SequenceMethod.GroupByElementSelector, SequenceMethod.GroupByElementSelectorResultSelector, SequenceMethod.GroupByResultSelector) { } // The default translation of GroupBy is: // SELECT d as Key, (SELECT VALUE g FROM source WHERE source.Key = d) as Group // FROM (SELECT DISTINCT source.Key) // // The optimized translation is simply creating a Cqt GroupByExpression internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod) { // translate source CqtExpression source = parent.TranslateSet(call.Arguments[0]); // translate key selector LambdaExpression keySelectorLinq = parent.GetLambdaExpression(call, 1); DbExpressionBinding sourceBinding; CqtExpression keySelector = parent.TranslateLambda(keySelectorLinq, source, out sourceBinding); // translate the key selector again in a different binding context (for the nested select) DbExpressionBinding nestedSourceBinding; CqtExpression nestedSelector = parent.TranslateLambda(keySelectorLinq, source, out nestedSourceBinding); // create distinct expression if (!TypeSemantics.IsEqualComparable(keySelector.ResultType)) { // to avoid confusing error message about the "distinct" type, pre-emptively raise an exception // about the group by key selector throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name)); } CqtExpression distinct = parent.Distinct( 
parent._commandTree.CreateProjectExpression(sourceBinding, keySelector)); DbExpressionBinding distinctBinding = parent._commandTree.CreateExpressionBinding(distinct); // create group projection term DbFilterExpression groupKeyFilter = parent.Filter( nestedSourceBinding, parent.CreateEqualsExpression(nestedSelector, distinctBinding.Variable, EqualsPattern.PositiveNullEquality, keySelectorLinq.Type, keySelectorLinq.Type)); // interpret element selector if needed CqtExpression selection = groupKeyFilter; bool hasElementSelector = sequenceMethod == SequenceMethod.GroupByElementSelector || sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector; if (hasElementSelector) { LambdaExpression elementSelectorLinq = parent.GetLambdaExpression(call, 2); DbExpressionBinding elementSelectorSourceBinding; CqtExpression elementSelector = parent.TranslateLambda(elementSelectorLinq, selection, out elementSelectorSourceBinding); selection = parent._commandTree.CreateProjectExpression(elementSelectorSourceBinding, elementSelector); } // create top level projection List projectionTerms = new List (2); projectionTerms.Add(distinctBinding.Variable); projectionTerms.Add(selection); // build projection type with initializer information List properties = new List (2); properties.Add(new EdmProperty(KeyColumnName, projectionTerms[0].ResultType)); properties.Add(new EdmProperty(GroupColumnName, projectionTerms[1].ResultType)); InitializerMetadata initializerMetadata = InitializerMetadata.CreateGroupingInitializer( parent.EdmItemCollection, TypeSystem.GetElementType(call.Type)); RowType rowType = new RowType(properties, initializerMetadata); TypeUsage rowTypeUsage = TypeUsage.Create(rowType); CqtExpression topLevelProject = parent._commandTree.CreateProjectExpression(distinctBinding, parent._commandTree.CreateNewInstanceExpression(rowTypeUsage, projectionTerms)); if (!hasElementSelector) { //Create optimized translation for the GroupBy - simple GroupBy template 
DbGroupExpressionBinding groupByBinding; CqtExpression newKeySelector = parent.TranslateLambda(keySelectorLinq, source, out groupByBinding); DbGroupByTemplate groupByTemplate = new DbGroupByTemplate(groupByBinding); groupByTemplate.GroupKeys.Add(new KeyValuePair (KeyColumnName, newKeySelector)); parent._groupByDefaultToOptimizedTranslationMap.Add(topLevelProject, groupByTemplate); } var result = topLevelProject; // GroupBy may include a result selector; handle it result = ProcessResultSelector(parent, call, sequenceMethod, topLevelProject, result); return result; } private static DbExpression ProcessResultSelector(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod, CqtExpression topLevelProject, DbExpression result) { // interpret result selector if needed LambdaExpression resultSelectorLinqExpression = null; if (sequenceMethod == SequenceMethod.GroupByResultSelector) { resultSelectorLinqExpression = parent.GetLambdaExpression(call, 2); } else if (sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector) { resultSelectorLinqExpression = parent.GetLambdaExpression(call, 3); } if (null != resultSelectorLinqExpression) { // selector maps (Key, Group) -> Result // push bindings for key and group DbExpressionBinding topLevelProjectBinding = parent._commandTree.CreateExpressionBinding(topLevelProject); parent._variableNameToInputExpression.Add(topLevelProjectBinding.VariableName, topLevelProject); DbPropertyExpression keyExpression = parent._commandTree.CreatePropertyExpression( KeyColumnName, topLevelProjectBinding.Variable); DbPropertyExpression groupExpression = parent._commandTree.CreatePropertyExpression( GroupColumnName, topLevelProjectBinding.Variable); parent._bindingContext.PushBindingScope( new Binding(resultSelectorLinqExpression.Parameters[0], keyExpression), new Binding(resultSelectorLinqExpression.Parameters[1], groupExpression)); // translate selector CqtExpression resultSelector = parent.TranslateExpression( 
resultSelectorLinqExpression.Body); result = parent._commandTree.CreateProjectExpression(topLevelProjectBinding, resultSelector); // see if the selector can be optimized CqtExpression rewrittenExpression; if (parent.TryRewrite(topLevelProject, topLevelProjectBinding, resultSelector, out rewrittenExpression)) { result = rewrittenExpression; } parent._bindingContext.PopBindingScope(); } return result; } internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call) { Debug.Fail("unreachable code"); return null; } } private sealed class GroupJoinTranslator : SequenceMethodTranslator { internal GroupJoinTranslator() : base(SequenceMethod.GroupJoin) { } internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call) { // o.GroupJoin(i, ok => outerKeySelector, ik => innerKeySelector, (o, i) => projection) // --> // SELECT projection(o, i) // FROM ( // SELECT o, (SELECT i FROM i WHERE o.outerKeySelector = i.innerKeySelector) as i // FROM o) // translate inputs CqtExpression outer = parent.TranslateSet(call.Arguments[0]); CqtExpression inner = parent.TranslateSet(call.Arguments[1]); // translate key selectors DbExpressionBinding outerBinding; DbExpressionBinding innerBinding; LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2); LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3); CqtExpression outerSelector = parent.TranslateLambda( outerLambda, outer, out outerBinding); CqtExpression innerSelector = parent.TranslateLambda( innerLambda, inner, out innerBinding); // create innermost SELECT i FROM i WHERE ... 
if (!TypeSemantics.IsEqualComparable(outerSelector.ResultType) || !TypeSemantics.IsEqualComparable(innerSelector.ResultType)) { throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name)); } CqtExpression nestedCollection = parent.Filter(innerBinding, parent.CreateEqualsExpression(outerSelector, innerSelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type)); // create "join" SELECT o, (nestedCollection) const string outerColumn = "o"; const string innerColumn = "i"; List > recordColumns = new List >(2); recordColumns.Add(new KeyValuePair (outerColumn, outerBinding.Variable)); recordColumns.Add(new KeyValuePair (innerColumn, nestedCollection)); CqtExpression joinProjection = parent._commandTree.CreateNewRowExpression(recordColumns); CqtExpression joinProject = parent._commandTree.CreateProjectExpression(outerBinding, joinProjection); DbExpressionBinding joinProjectBinding = parent._commandTree.CreateExpressionBinding(joinProject); // create property expressions for the outer and inner terms to bind to the parameters to the // group join selector CqtExpression outerProperty = parent._commandTree.CreatePropertyExpression(outerColumn, joinProjectBinding.Variable); CqtExpression innerProperty = parent._commandTree.CreatePropertyExpression(innerColumn, joinProjectBinding.Variable); // push the inner and the outer terms into the binding scope LambdaExpression linqSelector = parent.GetLambdaExpression(call, 4); parent._bindingContext.PushBindingScope( new Binding(linqSelector.Parameters[0], outerProperty), new Binding(linqSelector.Parameters[1], innerProperty)); // translate the selector CqtExpression selectorProject = parent.TranslateExpression(linqSelector.Body); // pop the binding scope parent._bindingContext.PopBindingScope(); // create the selector projection CqtExpression selector = parent._commandTree.CreateProjectExpression(joinProjectBinding, selectorProject); return selector; } } 
/// <summary>
/// Shared translation for OrderBy / OrderByDescending: the translated key lambda becomes the
/// single sort clause of a sort expression over the source binding.
/// </summary>
private abstract class OrderByTranslatorBase : OneLambdaTranslator
{
    // sort direction handed to CreateSortClause
    private readonly bool _ascending;

    protected OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Builds a one-element sort key list from <paramref name="lambda"/> and sorts
    /// <paramref name="sourceBinding"/> by it.
    /// </summary>
    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        List<DbSortClause> keys = new List<DbSortClause>(1);
        DbSortClause sortSpec = parent._commandTree.CreateSortClause(lambda, _ascending);
        keys.Add(sortSpec);
        DbSortExpression sort = parent.Sort(sourceBinding, keys);
        return sort;
    }
}

/// <summary>
/// Binds OrderByTranslatorBase to SequenceMethod.OrderBy (ascending: true).
/// </summary>
private sealed class OrderByTranslator : OrderByTranslatorBase
{
    internal OrderByTranslator() : base(true, SequenceMethod.OrderBy) { }
}

/// <summary>
/// Binds OrderByTranslatorBase to SequenceMethod.OrderByDescending (ascending: false).
/// </summary>
private sealed class OrderByDescendingTranslator : OrderByTranslatorBase
{
    internal OrderByDescendingTranslator() : base(false, SequenceMethod.OrderByDescending) { }
}

// Note: because we need to "push-down" the expression binding for ThenBy, this class
// does not inherit from OneLambdaTranslator, although it is similar.
/// <summary>
/// Shared translation for ThenBy / ThenByDescending: appends one more sort clause to the sort
/// expression produced by the preceding ordering call.
/// </summary>
private abstract class ThenByTranslatorBase : SequenceMethodTranslator
{
    // sort direction of the appended clause
    private readonly bool _ascending;

    protected ThenByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Translates the call. Throws InvalidOperation (ELinq_ThenByDoesNotFollowOrderBy) when the
    /// translated source is not a sort expression.
    /// </summary>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(2 == call.Arguments.Count);
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);
        if (DbExpressionKind.Sort != source.ExpressionKind)
        {
            throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.ELinq_ThenByDoesNotFollowOrderBy);
        }

        DbSortExpression sortExpression = (DbSortExpression)source;

        // retrieve information about existing sort
        DbExpressionBinding binding = sortExpression.Input;

        // get information on new sort term
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        ParameterExpression parameter = lambdaExpression.Parameters[0];

        // push-down the binding scope information and translate the new sort key
        parent._bindingContext.PushBindingScope(new Binding(parameter, binding.Variable));
        CqtExpression lambda = parent.TranslateExpression(lambdaExpression.Body);
        parent._bindingContext.PopBindingScope();

        // create a new sort expression: existing sort order first, then the new key
        List<DbSortClause> keys = new List<DbSortClause>(sortExpression.SortOrder);
        keys.Add(new DbSortClause(lambda, _ascending, null));
        sortExpression = parent.Sort(binding, keys);

        return sortExpression;
    }
}

/// <summary>
/// Binds ThenByTranslatorBase to SequenceMethod.ThenBy (ascending: true).
/// </summary>
private sealed class ThenByTranslator : ThenByTranslatorBase
{
    internal ThenByTranslator() : base(true, SequenceMethod.ThenBy) { }
}

/// <summary>
/// Binds ThenByTranslatorBase to SequenceMethod.ThenByDescending (ascending: false).
/// </summary>
private sealed class ThenByDescendingTranslator : ThenByTranslatorBase
{
    internal ThenByDescendingTranslator() : base(false, SequenceMethod.ThenByDescending) { }
}
#endregion
}
}
}

// File provided for Reference Use Only by Microsoft Corporation (c) 2007.
Link Menu
This book is available now!
Buy at Amazon US or
Buy at Amazon UK
- StylusButton.cs
- ModelServiceImpl.cs
- TextTreeTextElementNode.cs
- TemplateBamlTreeBuilder.cs
- EventData.cs
- GridItemCollection.cs
- ReceiveErrorHandling.cs
- Int32Rect.cs
- Operator.cs
- ColumnWidthChangingEvent.cs
- ReadOnlyObservableCollection.cs
- MeshGeometry3D.cs
- RsaSecurityKey.cs
- ToggleProviderWrapper.cs
- ManifestSignatureInformation.cs
- DBConcurrencyException.cs
- RoutedPropertyChangedEventArgs.cs
- XmlSchemaCompilationSettings.cs
- PeerTransportSecuritySettings.cs
- ColorConvertedBitmapExtension.cs
- MultiView.cs
- AnimationException.cs
- WindowsListViewGroupSubsetLink.cs
- HandleExceptionArgs.cs
- PersonalizationProviderHelper.cs
- SettingsPropertyWrongTypeException.cs
- Stack.cs
- WebPartVerbCollection.cs
- CodeThrowExceptionStatement.cs
- FileInfo.cs
- LinkLabelLinkClickedEvent.cs
- ApplicationSettingsBase.cs
- NetworkInformationPermission.cs
- DbTypeMap.cs
- MimeMultiPart.cs
- Size.cs
- ResourceReferenceKeyNotFoundException.cs
- TypeDelegator.cs
- DayRenderEvent.cs
- SqlUserDefinedTypeAttribute.cs
- GZipDecoder.cs
- DBConnection.cs
- DigitShape.cs
- Exceptions.cs
- MetadataPropertyAttribute.cs
- ToolStripGripRenderEventArgs.cs
- NumberFormatInfo.cs
- CustomErrorCollection.cs
- SqlNodeTypeOperators.cs
- AutomationTextAttribute.cs
- ChannelDispatcherCollection.cs
- HttpModulesSection.cs
- NativeRecognizer.cs
- PathGradientBrush.cs
- SupportingTokenChannel.cs
- WebPartZone.cs
- TcpConnectionPoolSettings.cs
- MarkupObject.cs
- ResolveRequestResponseAsyncResult.cs
- CharStorage.cs
- TextElementEnumerator.cs
- RequestSecurityTokenResponseCollection.cs
- CodeGroup.cs
- SelectionUIHandler.cs
- WmlListAdapter.cs
- GuidConverter.cs
- ReferenceEqualityComparer.cs
- LicenseContext.cs
- SingleAnimation.cs
- QilParameter.cs
- LayoutEditorPart.cs
- RegionInfo.cs
- SqlMethodAttribute.cs
- AddInActivator.cs
- DateTimeConverter2.cs
- CurrentTimeZone.cs
- LinqToSqlWrapper.cs
- NameGenerator.cs
- ClrPerspective.cs
- HandlerMappingMemo.cs
- FlowDocumentPaginator.cs
- ApplicationGesture.cs
- DataSourceProvider.cs
- OdbcEnvironmentHandle.cs
- TextLineResult.cs
- ApplicationManager.cs
- WeakEventManager.cs
- ExtendedPropertyCollection.cs
- Path.cs
- UnicastIPAddressInformationCollection.cs
- AutomationProperty.cs
- PermissionListSet.cs
- BitmapImage.cs
- Util.cs
- XmlAutoDetectWriter.cs
- CodeAccessSecurityEngine.cs
- EntityDesignerDataSourceView.cs
- InvariantComparer.cs
- Literal.cs
- DesignerAdapterAttribute.cs