// Reference source path:
// Dotnetfx_Vista_SP2 / Dotnetfx_Vista_SP2 / 8.0.50727.4016 / DEVDIV / depot / DevDiv / releases / Orcas / QFE / ndp / fx / src / DataEntity / System / Data / Objects / ELinq / MethodCallTranslator.cs / 1 / MethodCallTranslator.cs
//----------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
// @owner [....], [....]
//---------------------------------------------------------------------
using System.Data.Common.CommandTrees;
using System.Collections.Generic;
using CqtExpression = System.Data.Common.CommandTrees.DbExpression;
using LinqExpression = System.Linq.Expressions.Expression;
using System.Diagnostics;
using System.Data.Metadata.Edm;
using System.Linq.Expressions;
using System.Reflection;
using System.Linq;
using System.Data.Entity;
using System.Data.Common;
using System.Globalization;
namespace System.Data.Objects.ELinq
{
internal sealed partial class ExpressionConverter
{
/// <summary>
/// Translates a System.Linq.Expressions.MethodCallExpression into a System.Data.Common.CommandTrees.DbExpression.
/// </summary>
private sealed class MethodCallTranslator : TypedTranslator
{
internal MethodCallTranslator()
    : base(ExpressionType.Call)
{
    // This translator handles expression nodes of type ExpressionType.Call.
}
/// <summary>
/// Translates a LINQ method call by trying, in order: a translator for a recognized sequence
/// method, a translator registered for the exact MethodInfo, a translator for a public
/// ObjectQuery&lt;T&gt; builder method (matched by name), and finally the default translator.
/// </summary>
protected override CqtExpression TypedTranslate(ExpressionConverter parent, MethodCallExpression linq)
{
    MethodInfo method = linq.Method;

    // First choice: a recognized sequence method (e.g. Select) with a dedicated translator.
    SequenceMethod sequenceMethod;
    SequenceMethodTranslator sequenceTranslator;
    if (ReflectionUtil.TryIdentifySequenceMethod(method, out sequenceMethod) &&
        s_sequenceTranslators.TryGetValue(sequenceMethod, out sequenceTranslator))
    {
        return sequenceTranslator.Translate(parent, linq, sequenceMethod);
    }

    // Second choice: a translator registered for this specific method
    // (the lookup may lazily register the Visual Basic translators).
    CallTranslator methodTranslator;
    if (TryGetCallTranslator(method, out methodTranslator))
    {
        return methodTranslator.Translate(parent, linq);
    }

    // Third choice: a public method declared on a closed ObjectQuery<T> type, looked up by name.
    Type owner = method.DeclaringType;
    bool declaredOnObjectQuery =
        method.IsPublic &&
        owner != null &&
        owner.IsGenericType &&
        owner.GetGenericTypeDefinition() == typeof(ObjectQuery<>);
    ObjectQueryCallTranslator builderTranslator;
    if (declaredOnObjectQuery &&
        s_objectQueryTranslators.TryGetValue(method.Name, out builderTranslator))
    {
        return builderTranslator.Translate(parent, linq);
    }

    // Last resort: the default translator, which rejects the method as unsupported.
    return s_defaultTranslator.Translate(parent, linq);
}
#region Static members and initializers
private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

// Fall-back translator used when no other translator recognizes the method.
private static readonly CallTranslator s_defaultTranslator = new DefaultTranslator();

// Translators keyed by the exact MethodInfo they handle (e.g. String.Contains).
// Once published, a dictionary instance stored here is never mutated. TryGetCallTranslator
// reads it without taking a lock, and System.Collections.Generic.Dictionary does not support
// reads concurrent with writes. InitializeVBMethods therefore builds an extended copy and
// publishes it with a single volatile write.
private static volatile Dictionary<MethodInfo, CallTranslator> s_methodTranslators = InitializeMethodTranslators();

// Translators for recognized sequence methods (e.g. Select, Where).
// Must be declared (and thus initialized) before s_objectQueryTranslators: the
// ObjectQueryBuilderCallTranslator constructor looks up its sequence equivalent here.
private static readonly Dictionary<SequenceMethod, SequenceMethodTranslator> s_sequenceTranslators = InitializeSequenceMethodTranslators();

// Translators for ObjectQuery<T> builder methods, keyed by method name (ordinal comparison).
private static readonly Dictionary<string, ObjectQueryCallTranslator> s_objectQueryTranslators = InitializeObjectQueryTranslators();

// True once the Visual Basic translators have been merged into s_methodTranslators.
// Read and written only while holding s_vbInitializerLock.
private static bool s_vbMethodsInitialized;
private static readonly object s_vbInitializerLock = new object();

/// <summary>
/// Builds the MethodInfo-to-translator map for the built-in (non-Visual Basic) call translators.
/// </summary>
/// <returns>A new dictionary mapping each handled method to its translator.</returns>
private static Dictionary<MethodInfo, CallTranslator> InitializeMethodTranslators()
{
    Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>();
    foreach (CallTranslator translator in GetCallTranslators())
    {
        foreach (MethodInfo method in translator.Methods)
        {
            // Add throws on a duplicate key, so a method can be claimed by only one translator.
            methodTranslators.Add(method, translator);
        }
    }
    return methodTranslators;
}

/// <summary>
/// Builds the sequence-method-to-translator map (e.g. Select, Where, OrderBy).
/// </summary>
/// <returns>A new dictionary mapping each handled sequence method to its translator.</returns>
private static Dictionary<SequenceMethod, SequenceMethodTranslator> InitializeSequenceMethodTranslators()
{
    Dictionary<SequenceMethod, SequenceMethodTranslator> sequenceTranslators = new Dictionary<SequenceMethod, SequenceMethodTranslator>();
    foreach (SequenceMethodTranslator translator in GetSequenceMethodTranslators())
    {
        foreach (SequenceMethod method in translator.Methods)
        {
            sequenceTranslators.Add(method, translator);
        }
    }
    return sequenceTranslators;
}

/// <summary>
/// Builds the name-to-translator map for ObjectQuery&lt;T&gt; methods
/// (e.g. ObjectQuery.OfType(), ObjectQuery.Include(string)).
/// </summary>
/// <returns>A new dictionary keyed by method name using ordinal comparison.</returns>
private static Dictionary<string, ObjectQueryCallTranslator> InitializeObjectQueryTranslators()
{
    Dictionary<string, ObjectQueryCallTranslator> objectQueryCallTranslators = new Dictionary<string, ObjectQueryCallTranslator>(StringComparer.Ordinal);
    foreach (ObjectQueryCallTranslator translator in GetObjectQueryCallTranslators())
    {
        // Indexer assignment: a later translator with the same name replaces an earlier one.
        objectQueryCallTranslators[translator.MethodName] = translator;
    }
    return objectQueryCallTranslators;
}

/// <summary>
/// Tries to get a translator for the given method. If the method is declared in the
/// Visual Basic runtime assembly, the Visual Basic translators are registered first
/// (once per AppDomain) and the lookup is retried.
/// </summary>
/// <param name="methodInfo">The method being called.</param>
/// <param name="callTranslator">The matching translator, or null if none exists.</param>
/// <returns>True if a translator was found; otherwise false.</returns>
private static bool TryGetCallTranslator(MethodInfo methodInfo, out CallTranslator callTranslator)
{
    // Lock-free fast path: the published dictionary is never mutated (see s_methodTranslators).
    if (s_methodTranslators.TryGetValue(methodInfo, out callTranslator))
    {
        return true;
    }

    // Global (module-level) methods have no declaring type; they cannot be Visual Basic runtime methods.
    Type declaringType = methodInfo.DeclaringType;
    if (null != declaringType && s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
    {
        lock (s_vbInitializerLock)
        {
            if (!s_vbMethodsInitialized)
            {
                InitializeVBMethods(declaringType.Assembly);
                s_vbMethodsInitialized = true;
            }
            // Try again against the (possibly newly published) dictionary.
            return s_methodTranslators.TryGetValue(methodInfo, out callTranslator);
        }
    }

    callTranslator = null;
    return false;
}

/// <summary>
/// Publishes a new translator map that also contains the Visual Basic translators.
/// Must be called while holding s_vbInitializerLock.
/// </summary>
/// <param name="vbAssembly">The loaded Visual Basic runtime assembly.</param>
private static void InitializeVBMethods(Assembly vbAssembly)
{
    Debug.Assert(!s_vbMethodsInitialized);

    // Copy-on-write: concurrent readers keep using the old instance until the new one is
    // fully built; if registration throws, nothing is published and a later call can retry.
    Dictionary<MethodInfo, CallTranslator> extended = new Dictionary<MethodInfo, CallTranslator>(s_methodTranslators);
    foreach (CallTranslator translator in GetVisualBasicCallTranslators(vbAssembly))
    {
        foreach (MethodInfo method in translator.Methods)
        {
            extended.Add(method, translator);
        }
    }
    s_methodTranslators = extended;
}
/// <summary>
/// Creates the translators for Visual Basic runtime methods defined in <paramref name="vbAssembly"/>.
/// </summary>
/// <param name="vbAssembly">The loaded Visual Basic runtime assembly.</param>
/// <returns>A lazily evaluated sequence of the Visual Basic call translators.</returns>
private static IEnumerable<CallTranslator> GetVisualBasicCallTranslators(Assembly vbAssembly)
{
    yield return new VBCanonicalFunctionDefaultTranslator(vbAssembly);
    yield return new VBCanonicalFunctionRenameTranslator(vbAssembly);
    yield return new VBDatePartTranslator(vbAssembly);
}
/// <summary>
/// Creates the built-in translators for specific CLR methods (Math, Decimal and String members).
/// </summary>
/// <returns>A lazily evaluated sequence of call translators.</returns>
private static IEnumerable<CallTranslator> GetCallTranslators()
{
    yield return new CanonicalFunctionDefaultTranslator();
    yield return new ContainsTranslator();
    yield return new StartsWithTranslator();
    yield return new EndsWithTranslator();
    yield return new IndexOfTranslator();
    yield return new SubstringTranslator();
    yield return new RemoveTranslator();
    yield return new InsertTranslator();
    yield return new IsNullOrEmptyTranslator();
    yield return new StringConcatTranslator();
    yield return new TrimStartTranslator();
    yield return new TrimEndTranslator();
}
/// <summary>
/// Creates the translators for recognized sequence methods (Select, Where, Join, ...).
/// </summary>
/// <returns>A lazily evaluated sequence of sequence-method translators.</returns>
private static IEnumerable<SequenceMethodTranslator> GetSequenceMethodTranslators()
{
    // Set operations
    yield return new ConcatTranslator();
    yield return new UnionTranslator();
    yield return new IntersectTranslator();
    yield return new ExceptTranslator();
    yield return new DistinctTranslator();
    // Filtering, projection and ordering
    yield return new WhereTranslator();
    yield return new SelectTranslator();
    yield return new OrderByTranslator();
    yield return new OrderByDescendingTranslator();
    yield return new ThenByTranslator();
    yield return new ThenByDescendingTranslator();
    yield return new SelectManyTranslator();
    // Quantifiers, joins and grouping
    yield return new AnyTranslator();
    yield return new AnyPredicateTranslator();
    yield return new AllTranslator();
    yield return new JoinTranslator();
    yield return new GroupByTranslator();
    // Aggregates
    yield return new MaxTranslator();
    yield return new MinTranslator();
    yield return new AverageTranslator();
    yield return new SumTranslator();
    yield return new CountTranslator();
    yield return new LongCountTranslator();
    // Type operations, element operators and paging
    yield return new CastMethodTranslator();
    yield return new GroupJoinTranslator();
    yield return new OfTypeTranslator();
    yield return new SingleTranslatorNotSupported();
    yield return new PassthroughTranslator();
    yield return new FirstTranslator();
    yield return new FirstPredicateTranslator();
    yield return new FirstOrDefaultTranslator();
    yield return new FirstOrDefaultPredicateTranslator();
    yield return new TakeTranslator();
    yield return new SkipTranslator();
}
/// <summary>
/// Creates the translators for ObjectQuery&lt;T&gt; builder methods (matched by method name).
/// </summary>
/// <returns>A lazily evaluated sequence of ObjectQuery method translators.</returns>
private static IEnumerable<ObjectQueryCallTranslator> GetObjectQueryCallTranslators()
{
    yield return new ObjectQueryBuilderDistinctTranslator();
    yield return new ObjectQueryBuilderExceptTranslator();
    yield return new ObjectQueryBuilderFirstTranslator();
    yield return new ObjectQueryIncludeTranslator();
    yield return new ObjectQueryBuilderIntersectTranslator();
    yield return new ObjectQueryBuilderOfTypeTranslator();
    yield return new ObjectQueryBuilderUnionTranslator();
}
#endregion
#region Method translators
/// <summary>
/// Base class for translators that turn calls to specific CLR methods into command tree expressions.
/// </summary>
private abstract class CallTranslator
{
    private readonly IEnumerable<MethodInfo> _methods;

    /// <param name="methods">The methods this translator handles (may be empty).</param>
    protected CallTranslator(params MethodInfo[] methods) { _methods = methods; }

    /// <param name="methods">The methods this translator handles; may be a lazily evaluated sequence.</param>
    protected CallTranslator(IEnumerable<MethodInfo> methods) { _methods = methods; }

    /// <summary>Gets the methods handled by this translator, exactly as passed to the constructor.</summary>
    internal IEnumerable<MethodInfo> Methods { get { return _methods; } }

    /// <summary>Translates <paramref name="call"/> into a command tree expression.</summary>
    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    /// <summary>Returns the concrete translator type name (useful when debugging).</summary>
    public override string ToString()
    {
        return GetType().Name;
    }
}
/// <summary>
/// Base class for translators of ObjectQuery&lt;T&gt; methods. These translators are
/// registered and looked up by method name rather than by MethodInfo.
/// </summary>
private abstract class ObjectQueryCallTranslator : CallTranslator
{
    private readonly string _methodName;

    /// <param name="methodName">Name of the ObjectQuery&lt;T&gt; method handled by this translator.</param>
    protected ObjectQueryCallTranslator(string methodName)
        : base()
    {
        _methodName = methodName;
    }

    /// <summary>Gets the name of the ObjectQuery&lt;T&gt; method handled by this translator.</summary>
    internal string MethodName
    {
        get { return _methodName; }
    }
}
/// <summary>
/// Translates an ObjectQuery&lt;T&gt; builder method by delegating to the translator of
/// its equivalent sequence method.
/// </summary>
private abstract class ObjectQueryBuilderCallTranslator : ObjectQueryCallTranslator
{
    // Translator for the sequence method that this builder method mirrors.
    private readonly SequenceMethodTranslator _translator;

    /// <param name="methodName">Name of the ObjectQuery&lt;T&gt; builder method.</param>
    /// <param name="sequenceEquivalent">The sequence method whose translator performs the work.</param>
    protected ObjectQueryBuilderCallTranslator(string methodName, SequenceMethod sequenceEquivalent)
        : base(methodName)
    {
        // Relies on s_sequenceTranslators having been initialized already.
        bool found = s_sequenceTranslators.TryGetValue(sequenceEquivalent, out _translator);
        Debug.Assert(found, "Translator not found for " + sequenceEquivalent.ToString());
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        SequenceMethodTranslator delegateTo = _translator;
        return delegateTo.Translate(parent, call);
    }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.Union using the SequenceMethod.Union translator.</summary>
private sealed class ObjectQueryBuilderUnionTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderUnionTranslator() : base("Union", SequenceMethod.Union) { }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.Intersect using the SequenceMethod.Intersect translator.</summary>
private sealed class ObjectQueryBuilderIntersectTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderIntersectTranslator() : base("Intersect", SequenceMethod.Intersect) { }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.Except using the SequenceMethod.Except translator.</summary>
private sealed class ObjectQueryBuilderExceptTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderExceptTranslator() : base("Except", SequenceMethod.Except) { }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.Distinct using the SequenceMethod.Distinct translator.</summary>
private sealed class ObjectQueryBuilderDistinctTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderDistinctTranslator() : base("Distinct", SequenceMethod.Distinct) { }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.OfType using the SequenceMethod.OfType translator.</summary>
private sealed class ObjectQueryBuilderOfTypeTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderOfTypeTranslator() : base("OfType", SequenceMethod.OfType) { }
}
/// <summary>Translates ObjectQuery&lt;T&gt;.First using the SequenceMethod.First translator.</summary>
private sealed class ObjectQueryBuilderFirstTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderFirstTranslator() : base("First", SequenceMethod.First) { }
}
/// <summary>
/// Translates ObjectQuery&lt;T&gt;.Include(string) by adding the include path to the span
/// associated with the translated source query.
/// </summary>
private sealed class ObjectQueryIncludeTranslator : ObjectQueryCallTranslator
{
    internal ObjectQueryIncludeTranslator() : base("Include") { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(string)), "Invalid Include arguments?");

        // Translate the source query and fetch any span already associated with it.
        CqtExpression source = parent.TranslateExpression(call.Object);
        Span existingSpan;
        if (!parent.TryGetSpan(source, out existingSpan))
        {
            existingSpan = null;
        }

        // The 'Include' method implementation on ELinqQueryState creates a method call
        // expression with a string constant argument holding the value passed to
        // ObjectQuery.Include, so a constant is the only supported pattern here.
        CqtExpression pathArgument = parent.TranslateExpression(call.Arguments[0]);
        if (pathArgument.ExpressionKind != DbExpressionKind.Constant)
        {
            throw EntityUtil.NotSupported(Entity.Strings.ELinq_UnsupportedInclude);
        }

        string includePath = (string)((DbConstantExpression)pathArgument).Value;
        return parent.AddSpanMapping(source, Span.IncludeIn(existingSpan, includePath));
    }
}
/// <summary>
/// Fall-back translator for methods that have no dedicated translator. It never produces an
/// expression; it always throws the exception built by EntityUtil.NotSupported, suggesting an
/// alternative method when one is known.
/// </summary>
private sealed class DefaultTranslator : CallTranslator
{
    internal DefaultTranslator() : base() { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        MethodInfo suggestedMethodInfo;
        if (TryGetAlternativeMethod(call.Method, out suggestedMethodInfo))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethodSuggestedAlternative(call.Method, suggestedMethodInfo));
        }
        // The default error message
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethod(call.Method));
    }

    #region Static Members
    // Maps an unsupported method to a supported alternative to mention in the error message.
    // Once published, a dictionary instance stored here is never mutated. TryGetAlternativeMethod
    // reads it without a lock, and Dictionary does not support reads concurrent with writes, so
    // InitializeVBMethods publishes an extended copy with a single volatile write instead.
    private static volatile Dictionary<MethodInfo, MethodInfo> s_alternativeMethods = InitializeAlternateMethodInfos();
    // True once the Visual Basic alternatives have been published; guarded by s_vbInitializerLock.
    private static bool s_vbMethodsInitialized;
    private static readonly object s_vbInitializerLock = new object();

    /// <summary>
    /// Tries to find an alternative method to suggest instead of the given unsupported one.
    /// For methods declared in the Visual Basic runtime assembly, the Visual Basic alternatives
    /// are registered first (once) and the lookup is retried.
    /// </summary>
    /// <param name="originalMethodInfo">The unsupported method.</param>
    /// <param name="suggestedMethodInfo">The suggested alternative, or null if none is known.</param>
    /// <returns>True if an alternative was found; otherwise false.</returns>
    private static bool TryGetAlternativeMethod(MethodInfo originalMethodInfo, out MethodInfo suggestedMethodInfo)
    {
        if (s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo))
        {
            return true;
        }

        // Global (module-level) methods have no declaring type; they cannot be Visual Basic runtime methods.
        Type declaringType = originalMethodInfo.DeclaringType;
        if (null != declaringType && s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
        {
            lock (s_vbInitializerLock)
            {
                if (!s_vbMethodsInitialized)
                {
                    InitializeVBMethods(declaringType.Assembly);
                    s_vbMethodsInitialized = true;
                }
                // Try again against the (possibly newly published) dictionary.
                return s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo);
            }
        }

        suggestedMethodInfo = null;
        return false;
    }

    /// <summary>
    /// Creates the initial dictionary of alternative methods. It is currently empty;
    /// Visual Basic entries are added lazily by InitializeVBMethods.
    /// </summary>
    /// <returns>An empty dictionary.</returns>
    private static Dictionary<MethodInfo, MethodInfo> InitializeAlternateMethodInfos()
    {
        return new Dictionary<MethodInfo, MethodInfo>(1);
    }

    /// <summary>
    /// Publishes a new alternatives dictionary that also contains the Visual Basic entries.
    /// Must be called while holding s_vbInitializerLock.
    /// </summary>
    /// <param name="vbAssembly">The loaded Visual Basic runtime assembly.</param>
    private static void InitializeVBMethods(Assembly vbAssembly)
    {
        Debug.Assert(!s_vbMethodsInitialized);

        // Copy-on-write so lock-free readers never observe a dictionary being modified.
        Dictionary<MethodInfo, MethodInfo> extended = new Dictionary<MethodInfo, MethodInfo>(s_alternativeMethods);

        // Mid(string, int) is unsupported; suggest Mid(string, int, int) instead.
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        extended.Add(
            stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null),
            stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int), typeof(int) }, null));

        s_alternativeMethods = extended;
    }
    #endregion
}
/// <summary>
/// Translates selected Math, Decimal and String methods directly into the canonical function
/// with the same name, passing the instance (if any) as the first argument.
/// </summary>
private sealed class CanonicalFunctionDefaultTranslator : CallTranslator
{
    internal CanonicalFunctionDefaultTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        const BindingFlags PublicStatic = BindingFlags.Public | BindingFlags.Static;
        const BindingFlags PublicInstance = BindingFlags.Public | BindingFlags.Instance;

        // Math functions
        yield return typeof(Math).GetMethod("Ceiling", PublicStatic, null, new Type[] { typeof(decimal) }, null);
        yield return typeof(Math).GetMethod("Ceiling", PublicStatic, null, new Type[] { typeof(double) }, null);
        yield return typeof(Math).GetMethod("Floor", PublicStatic, null, new Type[] { typeof(decimal) }, null);
        yield return typeof(Math).GetMethod("Floor", PublicStatic, null, new Type[] { typeof(double) }, null);
        yield return typeof(Math).GetMethod("Round", PublicStatic, null, new Type[] { typeof(decimal) }, null);
        yield return typeof(Math).GetMethod("Round", PublicStatic, null, new Type[] { typeof(double) }, null);

        // Decimal functions
        yield return typeof(Decimal).GetMethod("Floor", PublicStatic, null, new Type[] { typeof(decimal) }, null);
        yield return typeof(Decimal).GetMethod("Ceiling", PublicStatic, null, new Type[] { typeof(decimal) }, null);
        yield return typeof(Decimal).GetMethod("Round", PublicStatic, null, new Type[] { typeof(decimal) }, null);

        // String functions
        yield return typeof(String).GetMethod("Replace", PublicInstance, null, new Type[] { typeof(String), typeof(String) }, null);
        yield return typeof(String).GetMethod("ToLower", PublicInstance, null, new Type[] { }, null);
        yield return typeof(String).GetMethod("ToUpper", PublicInstance, null, new Type[] { }, null);
        yield return typeof(String).GetMethod("Trim", PublicInstance, null, new Type[] { }, null);
    }

    // Default translator for method calls into canonical functions.
    // Translation:
    //   MethodName(arg1, arg2, .., argn)      -> MethodName(arg1, arg2, .., argn)
    //   this.MethodName(arg1, arg2, .., argn) -> MethodName(this, arg1, arg2, .., argn)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression[] linqArguments;
        if (call.Method.IsStatic)
        {
            linqArguments = call.Arguments.ToArray();
        }
        else
        {
            Debug.Assert(call.Object != null, "Instance method without this");
            // The instance becomes the first argument of the canonical function.
            List<LinqExpression> arguments = new List<LinqExpression>(call.Arguments.Count + 1);
            arguments.Add(call.Object);
            arguments.AddRange(call.Arguments);
            linqArguments = arguments.ToArray();
        }
        return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, linqArguments);
    }
}
#region System.String Method Translators
/// <summary>Translates String.Contains(string) using the canonical IndexOf function.</summary>
private sealed class ContainsTranslator : CallTranslator
{
    internal ContainsTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("Contains", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null);
    }

    // Translation:
    //   object.Contains(argument) -> IndexOf(argument, object) > 0
    // The canonical IndexOf returns a 1-based position, so any positive result means "found".
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        DbComparisonExpression comparisonExpression = parent._commandTree.CreateGreaterThanExpression(
            indexOfExpression,
            parent._commandTree.CreateConstantExpression(0));
        return comparisonExpression;
    }
}
/// <summary>Translates String.IndexOf(string) using the canonical IndexOf function.</summary>
private sealed class IndexOfTranslator : CallTranslator
{
    internal IndexOfTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("IndexOf", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null);
    }

    // Translation:
    //   IndexOf(arg1) -> IndexOf(arg1, this) - 1
    // Subtracting 1 converts the canonical 1-based position into the 0-based .NET index.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IndexOf");
        DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        CqtExpression minusExpression = parent._commandTree.CreateMinusExpression(
            indexOfExpression,
            parent._commandTree.CreateConstantExpression(1));
        return minusExpression;
    }
}
/// <summary>Translates String.StartsWith(string) using the canonical IndexOf function.</summary>
private sealed class StartsWithTranslator : CallTranslator
{
    internal StartsWithTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("StartsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null);
    }

    // Translation:
    //   object.StartsWith(argument) -> IndexOf(argument, object) == 1
    // i.e. the first (1-based) occurrence of argument is at the very beginning of object.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        DbFunctionExpression indexOfExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression(
            indexOfExpression,
            parent._commandTree.CreateConstantExpression(1));
        return comparisonExpression;
    }
}
/// <summary>Translates String.EndsWith(string) using the canonical Right and Length functions.</summary>
private sealed class EndsWithTranslator : CallTranslator
{
    internal EndsWithTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("EndsWith", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(string) }, null);
    }

    // Translation:
    //   object.EndsWith(argument) -> Right(object, Length(argument)) == argument
    // The argument is translated separately for each of its two uses (order kept as originally written).
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        DbFunctionExpression lengthExpression = parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]);
        DbExpression rightExpression = parent.CreateCanonicalFunction(ExpressionConverter.Right, call,
            parent.TranslateExpression(call.Object),
            lengthExpression);
        DbComparisonExpression comparisonExpression = parent._commandTree.CreateEqualsExpression(
            rightExpression,
            parent.TranslateExpression(call.Arguments[0]));
        return comparisonExpression;
    }
}
/// <summary>Translates String.Substring(int) and String.Substring(int, int) into the canonical Substring function.</summary>
private sealed class SubstringTranslator : CallTranslator
{
    internal SubstringTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null);
        yield return typeof(String).GetMethod("Substring", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null);
    }

    // Translation (the canonical Substring start index is 1-based, hence the "+1"):
    //   Substring(arg1)       -> Substring(this, arg1 + 1, Length(this) - arg1)
    //   Substring(arg1, arg2) -> Substring(this, arg1 + 1, arg2)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Substring");

        CqtExpression target = parent.TranslateExpression(call.Object);
        CqtExpression fromIndex = parent._commandTree.CreatePlusExpression(
            parent.TranslateExpression(call.Arguments[0]),
            parent._commandTree.CreateConstantExpression(1));

        CqtExpression length;
        if (call.Arguments.Count == 1)
        {
            // No length given: take everything from startIndex to the end.
            length = parent._commandTree.CreateMinusExpression(
                parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
                parent.TranslateExpression(call.Arguments[0]));
        }
        else
        {
            length = parent.TranslateExpression(call.Arguments[1]);
        }

        CqtExpression substringExpression = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, target, fromIndex, length);
        return substringExpression;
    }
}
/// <summary>Translates String.Remove(int) and String.Remove(int, int) into canonical Substring/Concat calls.</summary>
private sealed class RemoveTranslator : CallTranslator
{
    internal RemoveTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int) }, null);
        yield return typeof(String).GetMethod("Remove", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(int) }, null);
    }

    // Translation (the canonical Substring start index is 1-based):
    //   Remove(arg1)       -> Substring(this, 1, arg1)
    //   Remove(arg1, arg2) -> Concat(Substring(this, 1, arg1), Substring(this, arg1 + arg2 + 1, Length(this) - (arg1 + arg2)))
    // Remove(arg1, arg2) is only supported if arg2 translates to a non-negative Int32 constant.
    // NOTE(review): arguments are re-translated for each use rather than sharing one translated
    // node; kept as originally written - confirm whether command tree nodes may be shared
    // before changing this.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Remove");

        // Substring(this, 1, arg1)
        CqtExpression result =
            parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
                parent.TranslateExpression(call.Object),
                parent._commandTree.CreateConstantExpression(1),
                parent.TranslateExpression(call.Arguments[0]));

        // Concat(result, Substring(this, (arg1 + arg2) + 1, Length(this) - (arg1 + arg2)))
        if (call.Arguments.Count == 2)
        {
            // With two arguments, only a non-negative constant count is supported.
            CqtExpression translatedArgument1 = parent.TranslateExpression(call.Arguments[1]);
            if (!IsNonNegativeIntegerConstant(translatedArgument1))
            {
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedStringRemoveCase(call.Method, call.Method.GetParameters()[1].Name));
            }

            // Start of the second substring: (arg1 + arg2) + 1
            CqtExpression substringStartIndex =
                parent._commandTree.CreatePlusExpression(
                    parent._commandTree.CreatePlusExpression(
                        parent.TranslateExpression(call.Arguments[0]),
                        translatedArgument1),
                    parent._commandTree.CreateConstantExpression(1));

            // Length of the second substring: Length(this) - (arg1 + arg2)
            CqtExpression substringLength =
                parent._commandTree.CreateMinusExpression(
                    parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
                    parent._commandTree.CreatePlusExpression(
                        parent.TranslateExpression(call.Arguments[0]),
                        parent.TranslateExpression(call.Arguments[1])));

            // Substring(this, substringStartIndex, substringLength)
            CqtExpression secondSubstring =
                parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
                    parent.TranslateExpression(call.Object),
                    substringStartIndex,
                    substringLength);

            // result = Concat(result, secondSubstring)
            result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, result, secondSubstring);
        }
        return result;
    }

    /// <summary>
    /// Returns true when <paramref name="argument"/> is an Int32 constant whose value is >= 0.
    /// </summary>
    private static bool IsNonNegativeIntegerConstant(CqtExpression argument)
    {
        // Must be a constant of type Int32.
        if (argument.ExpressionKind != DbExpressionKind.Constant ||
            !TypeSemantics.IsPrimitiveType(argument.ResultType, PrimitiveTypeKind.Int32))
        {
            return false;
        }

        // And its value must be non-negative.
        DbConstantExpression constantExpression = (DbConstantExpression)argument;
        return (int)constantExpression.Value >= 0;
    }
}
/// <summary>Translates String.Insert(int, string) into canonical Substring/Concat calls.</summary>
private sealed class InsertTranslator : CallTranslator
{
    internal InsertTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("Insert", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(int), typeof(string) }, null);
    }

    // Translation:
    //   Insert(startIndex, value) -> Concat(Concat(Substring(this, 1, startIndex), value), Substring(this, startIndex + 1, Length(this) - startIndex))
    // Translation calls are kept in their original order (value is translated last).
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 2, "Expecting 2 arguments for String.Insert");

        // Substring(this, 1, startIndex)
        CqtExpression firstSubstring =
            parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
                parent.TranslateExpression(call.Object),
                parent._commandTree.CreateConstantExpression(1),
                parent.TranslateExpression(call.Arguments[0]));

        // Substring(this, startIndex + 1, Length(this) - startIndex)
        CqtExpression secondSubstring =
            parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
                parent.TranslateExpression(call.Object),
                parent._commandTree.CreatePlusExpression(
                    parent.TranslateExpression(call.Arguments[0]),
                    parent._commandTree.CreateConstantExpression(1)),
                parent._commandTree.CreateMinusExpression(
                    parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
                    parent.TranslateExpression(call.Arguments[0])));

        // result = Concat(Concat(firstSubstring, value), secondSubstring)
        CqtExpression result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call,
            parent.CreateCanonicalFunction(ExpressionConverter.Concat, call,
                firstSubstring,
                parent.TranslateExpression(call.Arguments[1])),
            secondSubstring);
        return result;
    }
}
/// <summary>Translates String.IsNullOrEmpty(string) into an IsNull OR Length = 0 predicate.</summary>
private sealed class IsNullOrEmptyTranslator : CallTranslator
{
    internal IsNullOrEmptyTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("IsNullOrEmpty", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string) }, null);
    }

    // Translation:
    //   IsNullOrEmpty(value) -> (IsNull(value)) OR Length(value) = 0
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IsNullOrEmpty");

        // IsNull(value)
        CqtExpression isNullExpression =
            parent._commandTree.CreateIsNullExpression(
                parent.TranslateExpression(call.Arguments[0]));

        // Length(value) = 0
        CqtExpression emptyStringExpression =
            parent._commandTree.CreateEqualsExpression(
                parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]),
                parent._commandTree.CreateConstantExpression(0));

        CqtExpression result = parent._commandTree.CreateOrExpression(isNullExpression, emptyStringExpression);
        return result;
    }
}
/// <summary>Translates the 2-, 3- and 4-argument String.Concat(string, ...) overloads into nested canonical Concat calls.</summary>
private sealed class StringConcatTranslator : CallTranslator
{
    internal StringConcatTranslator()
        : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string) }, null);
        yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string) }, null);
        yield return typeof(String).GetMethod("Concat", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(string), typeof(string), typeof(string) }, null);
    }

    // Translation (left fold over the arguments):
    //   Concat(arg1, arg2)             -> Concat(arg1, arg2)
    //   Concat(arg1, arg2, arg3)       -> Concat(Concat(arg1, arg2), arg3)
    //   Concat(arg1, arg2, arg3, arg4) -> Concat(Concat(Concat(arg1, arg2), arg3), arg4)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count >= 2 && call.Arguments.Count <= 4, "Expecting between 2 and 4 arguments for String.Concat");

        CqtExpression result = parent.TranslateExpression(call.Arguments[0]);
        for (int argIndex = 1; argIndex < call.Arguments.Count; argIndex++)
        {
            // result = Concat(result, arg[argIndex])
            result = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call,
                result,
                parent.TranslateExpression(call.Arguments[argIndex]));
        }
        return result;
    }
}
/// <summary>
/// Shared translation for String.TrimStart/TrimEnd: mapped to a canonical trim function,
/// supported only when no trim characters are specified.
/// </summary>
private abstract class TrimStartTrimEndBaseTranslator : CallTranslator
{
    // Assigned once in the constructor.
    private readonly string _canonicalFunctionName;

    /// <param name="methods">The TrimStart/TrimEnd overloads handled.</param>
    /// <param name="canonicalFunctionName">The canonical function to translate into.</param>
    protected TrimStartTrimEndBaseTranslator(IEnumerable<MethodInfo> methods, string canonicalFunctionName)
        : base(methods)
    {
        _canonicalFunctionName = canonicalFunctionName;
    }

    // Translation:
    //   object.MethodName() -> CanonicalFunctionName(object)
    // Supported only if the trim-characters argument is an empty array initializer.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        if (!IsEmptyArray(call.Arguments[0]))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedTrimStartTrimEndCase(call.Method));
        }
        return parent.TranslateIntoCanonicalFunction(_canonicalFunctionName, call, call.Object);
    }

    /// <summary>
    /// Returns true when <paramref name="expression"/> is a NewArrayInit expression with no elements.
    /// Other array forms (e.g. NewArrayBounds, constants, variables) return false.
    /// </summary>
    internal static bool IsEmptyArray(LinqExpression expression)
    {
        if (expression.NodeType != ExpressionType.NewArrayInit)
        {
            return false;
        }
        NewArrayExpression newArray = (NewArrayExpression)expression;
        return newArray.Expressions.Count == 0;
    }
}
/// <summary>Translates String.TrimStart(params char[]) with no trim characters into the canonical LTrim function.</summary>
private sealed class TrimStartTranslator : TrimStartTrimEndBaseTranslator
{
    internal TrimStartTranslator()
        : base(GetMethods(), ExpressionConverter.LTrim) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("TrimStart", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null);
    }
}
/// <summary>Translates String.TrimEnd(params char[]) with no trim characters into the canonical RTrim function.</summary>
private sealed class TrimEndTranslator : TrimStartTrimEndBaseTranslator
{
    internal TrimEndTranslator()
        : base(GetMethods(), ExpressionConverter.RTrim) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        yield return typeof(String).GetMethod("TrimEnd", BindingFlags.Public | BindingFlags.Instance, null, new Type[] { typeof(Char[]) }, null);
    }
}
#endregion
#region Visual Basic Specific Translators
/// <summary>
/// Translates selected Visual Basic runtime string and date/time functions directly into the
/// canonical function with the same name and arguments.
/// </summary>
private sealed class VBCanonicalFunctionDefaultTranslator : CallTranslator
{
    private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";

    /// <param name="vbAssembly">The loaded Visual Basic runtime assembly.</param>
    internal VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly)) { }

    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        const BindingFlags PublicStatic = BindingFlags.Public | BindingFlags.Static;

        // Microsoft.VisualBasic.Strings
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        yield return stringsType.GetMethod("Trim", PublicStatic, null, new Type[] { typeof(string) }, null);
        yield return stringsType.GetMethod("LTrim", PublicStatic, null, new Type[] { typeof(string) }, null);
        yield return stringsType.GetMethod("RTrim", PublicStatic, null, new Type[] { typeof(string) }, null);
        yield return stringsType.GetMethod("Left", PublicStatic, null, new Type[] { typeof(string), typeof(int) }, null);
        yield return stringsType.GetMethod("Right", PublicStatic, null, new Type[] { typeof(string), typeof(int) }, null);

        // Microsoft.VisualBasic.DateAndTime
        Type dateTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        yield return dateTimeType.GetMethod("Year", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
        yield return dateTimeType.GetMethod("Month", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
        yield return dateTimeType.GetMethod("Day", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
        yield return dateTimeType.GetMethod("Hour", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
        yield return dateTimeType.GetMethod("Minute", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
        yield return dateTimeType.GetMethod("Second", PublicStatic, null, new Type[] { typeof(DateTime) }, null);
    }

    // Default translator for Visual Basic static method calls into canonical functions.
    // Translation:
    //   MethodName(arg1, arg2, .., argn) -> MethodName(arg1, arg2, .., argn)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, call.Arguments.ToArray());
    }
}
// Translates Microsoft.VisualBasic.Strings helpers whose arguments match a canonical function but whose
// name differs: Len -> Length, Mid -> Substring, UCase -> ToUpper, LCase -> ToLower
// (names taken from the corresponding ExpressionConverter constants).
private sealed class VBCanonicalFunctionRenameTranslator : CallTranslator
{
    private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

    // Maps each supported VB method to the canonical function name it translates into. It is filled
    // by GetMethod while the GetMethods iterator is enumerated, and read by Translate.
    private static readonly Dictionary<MethodInfo, string> s_methodNameMap = new Dictionary<MethodInfo, string>(4);

    internal VBCanonicalFunctionRenameTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly)) { }

    // Lazily yields the supported methods and registers their canonical names as a side effect.
    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        //Strings Types
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        yield return GetMethod(stringsType, "Len", ExpressionConverter.Length, new Type[] { typeof(string) });
        yield return GetMethod(stringsType, "Mid", ExpressionConverter.Substring, new Type[] { typeof(string), typeof(int), typeof(int) });
        yield return GetMethod(stringsType, "UCase", ExpressionConverter.ToUpper, new Type[] { typeof(string) });
        yield return GetMethod(stringsType, "LCase", ExpressionConverter.ToLower, new Type[] { typeof(string) });
    }

    // Looks up the public static overload declaringType.methodName(argumentTypes) and records that it
    // translates to canonicalFunctionName.
    private static MethodInfo GetMethod(Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes)
    {
        MethodInfo methodInfo = declaringType.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static, null, argumentTypes, null);
        // Assign through the indexer rather than calling Add. GetMethods is a lazy iterator, so this
        // registration runs again whenever the sequence handed to the base class is re-enumerated, or
        // when the translator is constructed again. Add would throw on the duplicate key; the
        // key/value pair is identical each time, so overwriting is safe.
        s_methodNameMap[methodInfo] = canonicalFunctionName;
        return methodInfo;
    }

    // Translator for static method calls into canonical functions when only the name of the canonical function
    // is different from the name of the method, but the arguments match.
    // Translation:
    //   MethodName(arg1, arg2, .., argn) -> CanonicalFunctionName(arg1, arg2, .., argn)
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        return parent.TranslateIntoCanonicalFunction(s_methodNameMap[call.Method], call, call.Arguments.ToArray());
    }
}
// Translates Microsoft.VisualBasic.DateAndTime.DatePart(DateInterval, DateTime, FirstDayOfWeek, FirstWeekOfYear)
// into the canonical function named after the (constant) interval argument, applied to the date.
private sealed class VBDatePartTranslator : CallTranslator
{
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";
    private const string s_DateIntervalFullName = "Microsoft.VisualBasic.DateInterval";
    private const string s_FirstDayOfWeekFullName = "Microsoft.VisualBasic.FirstDayOfWeek";
    private const string s_FirstWeekOfYearFullName = "Microsoft.VisualBasic.FirstWeekOfYear";

    // Interval names that have a matching canonical function. The interval argument's ToString()
    // must be one of these. The set is never modified after type initialization.
    private static readonly HashSet<string> s_supportedIntervals = new HashSet<string>
    {
        ExpressionConverter.Year,
        ExpressionConverter.Month,
        ExpressionConverter.Day,
        ExpressionConverter.Hour,
        ExpressionConverter.Minute,
        ExpressionConverter.Second
    };

    internal VBDatePartTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly)) { }

    // Lazily yields the four-argument DatePart overload from the given VB assembly.
    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        Type dateIntervalEnum = vbAssembly.GetType(s_DateIntervalFullName);
        Type firstDayOfWeekEnum = vbAssembly.GetType(s_FirstDayOfWeekFullName);
        Type firstWeekOfYearEnum = vbAssembly.GetType(s_FirstWeekOfYearFullName);
        yield return dateAndTimeType.GetMethod("DatePart", BindingFlags.Public | BindingFlags.Static, null,
            new Type[] { dateIntervalEnum, typeof(DateTime), firstDayOfWeekEnum, firstWeekOfYearEnum }, null);
    }

    // Translation:
    //   DatePart(DateInterval, date, arg3, arg4) -> 'DateInterval'(date)
    // Only supported when the interval is a constant whose name is in s_supportedIntervals;
    // arg3 and arg4 are ignored.
    // Throws NotSupported when the interval is not a constant or is not a supported value.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 4, "Expecting 4 arguments for Microsoft.VisualBasic.DateAndTime.DatePart");
        ConstantExpression intervalLinqExpression = call.Arguments[0] as ConstantExpression;
        if (intervalLinqExpression == null)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartNonConstantInterval(call.Method, call.Method.GetParameters()[0].Name));
        }
        string intervalValue = intervalLinqExpression.Value.ToString();
        if (!s_supportedIntervals.Contains(intervalValue))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartInvalidInterval(call.Method, call.Method.GetParameters()[0].Name, intervalValue));
        }
        CqtExpression result = parent.TranslateIntoCanonicalFunction(intervalValue, call, call.Arguments[1]);
        return result;
    }
}
#endregion
#endregion
#region Sequence method translators
// Base class for translators of LINQ sequence operators (identified by SequenceMethod) into
// command tree expressions.
private abstract class SequenceMethodTranslator
{
    // The sequence methods handled by this translator, exactly as passed to the constructor.
    private readonly IEnumerable<SequenceMethod> _methods;

    protected SequenceMethodTranslator(params SequenceMethod[] methods) { _methods = methods; }

    internal IEnumerable<SequenceMethod> Methods { get { return _methods; } }

    // Entry point that also receives the identified SequenceMethod. By default the specific method
    // is ignored and translation is delegated to Translate(parent, call); derived classes may override.
    internal virtual CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
    {
        return Translate(parent, call);
    }

    // Translates the given LINQ method call into a command tree expression.
    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    // Returns the concrete translator's type name.
    public override string ToString()
    {
        return GetType().Name;
    }
}
// Shared logic for Skip/Take. The count argument is translated, any projection at the root of the
// operand is peeled off, the paging operator is applied underneath it, and the projection is then
// re-attached on top of the paged result.
private abstract class PagingTranslator : UnarySequenceMethodTranslator
{
    protected PagingTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 2, "Skip and Take must have 2 arguments");
        CqtExpression count = parent.TranslateExpression(call.Arguments[1]);

        // A root projection (if present) is set aside so paging applies to its input.
        DbProjectExpression rootProjection = operand.ExpressionKind == DbExpressionKind.Project
            ? (DbProjectExpression)operand
            : null;
        CqtExpression pagingInput = null == rootProjection ? operand : rootProjection.Input.Expression;

        CqtExpression paged = TranslatePagingOperator(parent, pagingInput, count);
        if (null == rootProjection)
        {
            return paged;
        }

        // Re-point the original projection at the paged input (mutates the projection's binding).
        rootProjection.Input.Expression = paged;
        return rootProjection;
    }

    // Builds the concrete paging expression (limit or skip) over the given operand.
    protected abstract CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count);
}
// Take(n): restricts the operand to its first n elements via parent.Limit.
private sealed class TakeTranslator : PagingTranslator
{
    internal TakeTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Take })
    {
    }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        CqtExpression limited = parent.Limit(operand, count);
        return limited;
    }
}
// Skip(n): only valid directly over a sorted input. The sort is folded into a DbSkipExpression that
// reuses the original sort keys, and any span recorded for the sort is carried over to the skip.
private sealed class SkipTranslator : PagingTranslator
{
    internal SkipTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Skip })
    {
    }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        bool isSorted = operand.ExpressionKind == DbExpressionKind.Sort;
        if (!isSorted)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_SkipWithoutOrder);
        }

        DbSortExpression sortInput = (DbSortExpression)operand;
        Span spanOfSort;
        bool sortHasSpan = parent.TryGetSpan(sortInput, out spanOfSort);

        DbSkipExpression skipExpression = parent.Skip(sortInput.Input, sortInput.SortOrder, count);
        if (sortHasSpan)
        {
            parent.AddSpanMapping(skipExpression, spanOfSort);
        }
        return skipExpression;
    }
}
// Single/SingleOrDefault (with or without predicate) cannot be translated; every call is rejected
// with a NotSupported exception.
private sealed class SingleTranslatorNotSupported : SequenceMethodTranslator
{
    internal SingleTranslatorNotSupported()
        : base(new SequenceMethod[]
        {
            SequenceMethod.Single,
            SequenceMethod.SinglePredicate,
            SequenceMethod.SingleOrDefault,
            SequenceMethod.SingleOrDefaultPredicate
        })
    {
    }

    internal override System.Data.Common.CommandTrees.DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Unconditional: no translation exists for these operators.
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedSingle);
    }
}
// Translates Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector) into an inner
// join on key equality, followed by a projection of the result selector over the joined rows.
private sealed class JoinTranslator : SequenceMethodTranslator
{
    internal JoinTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Join })
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(5 == call.Arguments.Count);

        // Inputs (arguments 0 and 1).
        CqtExpression outerSet = parent.TranslateSet(call.Arguments[0]);
        CqtExpression innerSet = parent.TranslateSet(call.Arguments[1]);

        // Key selectors (arguments 2 and 3), each bound over its own input.
        LambdaExpression outerKeyLambda = parent.GetLambdaExpression(call, 2);
        LambdaExpression innerKeyLambda = parent.GetLambdaExpression(call, 3);
        DbExpressionBinding outerBinding;
        CqtExpression outerKey = parent.TranslateLambda(outerKeyLambda, outerSet, out outerBinding);
        DbExpressionBinding innerBinding;
        CqtExpression innerKey = parent.TranslateLambda(innerKeyLambda, innerSet, out innerBinding);

        // The join predicate is an equality test, so both key types must support equality.
        bool keysAreComparable = TypeSemantics.IsEqualComparable(outerKey.ResultType) &&
                                 TypeSemantics.IsEqualComparable(innerKey.ResultType);
        if (!keysAreComparable)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }

        CqtExpression keysEqual = parent.CreateEqualsExpression(outerKey, innerKey, EqualsPattern.PositiveNullEquality, outerKeyLambda.Body.Type, innerKeyLambda.Body.Type);
        DbJoinExpression join = parent._commandTree.CreateInnerJoinExpression(outerBinding, innerBinding, keysEqual);
        DbExpressionBinding joinBinding = parent._commandTree.CreateExpressionBinding(join);

        // Result selector (argument 4): its two parameters refer to the outer and inner columns of the join row.
        LambdaExpression resultLambda = parent.GetLambdaExpression(call, 4);
        DbPropertyExpression outerColumn = parent._commandTree.CreatePropertyExpression(
            outerBinding.VariableName, joinBinding.Variable);
        DbPropertyExpression innerColumn = parent._commandTree.CreatePropertyExpression(
            innerBinding.VariableName, joinBinding.Variable);

        // Binding order within the scope is irrelevant: the binding context matches on parameter
        // reference, not ordinal.
        parent._bindingContext.PushBindingScope(
            new Binding(resultLambda.Parameters[0], outerColumn),
            new Binding(resultLambda.Parameters[1], innerColumn));
        CqtExpression projected = parent.TranslateExpression(resultLambda.Body);
        parent._bindingContext.PopBindingScope();

        return parent._commandTree.CreateProjectExpression(joinBinding, projected);
    }
}
// Base for two-input set operators. The receiver and argument (instance form), or the first two
// arguments (static extension form), are translated as sets and combined by TranslateBinary.
private abstract class BinarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected BinarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression left;
        CqtExpression right;
        if (null == call.Object)
        {
            // static extension method: (left, right)
            Debug.Assert(2 == call.Arguments.Count);
            left = parent.TranslateSet(call.Arguments[0]);
            right = parent.TranslateSet(call.Arguments[1]);
        }
        else
        {
            // instance method: left.Method(right)
            Debug.Assert(1 == call.Arguments.Count);
            left = parent.TranslateSet(call.Object);
            right = parent.TranslateSet(call.Arguments[0]);
        }
        return TranslateBinary(parent, left, right);
    }

    // Combines the two translated input sets.
    protected abstract CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right);
}
// Concat: bag union of both inputs (duplicates kept).
private class ConcatTranslator : BinarySequenceMethodTranslator
{
    internal ConcatTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Concat })
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression concatenated = parent.UnionAll(left, right);
        return concatenated;
    }
}
// Union: set union, expressed as DISTINCT over UNION ALL.
private sealed class UnionTranslator : BinarySequenceMethodTranslator
{
    internal UnionTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Union })
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression allRows = parent.UnionAll(left, right);
        return parent.Distinct(allRows);
    }
}
// Intersect: delegates to parent.Intersect over the two inputs.
private sealed class IntersectTranslator : BinarySequenceMethodTranslator
{
    internal IntersectTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Intersect })
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression common = parent.Intersect(left, right);
        return common;
    }
}
// Except: delegates to parent.Except (left minus right).
private sealed class ExceptTranslator : BinarySequenceMethodTranslator
{
    internal ExceptTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Except })
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression difference = parent.Except(left, right);
        return difference;
    }
}
// Base translator for LINQ aggregate operators. The source set is optionally filtered (predicate lambda)
// or projected (selector lambda), and then aggregated with a key-less DbGroupByExpression holding a
// single canonical aggregate. The one resulting row is extracted with a DbElementExpression.
// When possible, an alternative translation that folds the aggregate into an enclosing group-by is
// also recorded (see TryCreateOptimizedTranslation).
private abstract class AggregateTranslator : SequenceMethodTranslator
{
    // Name of the canonical aggregate function (e.g. "MAX"), resolved by FindFunction.
    private readonly string _functionName;
    // True when the optional lambda is a predicate (filter); false when it is a selector (projection).
    private readonly bool _takesPredicate;

    protected AggregateTranslator(string functionName, bool takesPredicate, params SequenceMethod[] methods)
        : base(methods)
    {
        _takesPredicate = takesPredicate;
        _functionName = functionName;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool isUnary = 1 == call.Arguments.Count;
        Debug.Assert(isUnary || 2 == call.Arguments.Count);
        CqtExpression operand = parent.TranslateSet(call.Arguments[0]);
        // Facet information for the return type cannot help in determining the appropriate function overload
        // since no constraints on the return value are known.
        TypeUsage returnType = parent.GetValueLayerType(call.Type);
        LambdaExpression lambda = null;
        // Keep the untransformed source for the optimized translation attempted below.
        CqtExpression originalOperand = operand;
        if (!isUnary)
        {
            lambda = parent.GetLambdaExpression(call, 1);
            DbExpressionBinding sourceBinding;
            CqtExpression cqtLambda = parent.TranslateLambda(lambda, operand, out sourceBinding);
            if (_takesPredicate)
            {
                // treat the lambda as a filter
                operand = parent.Filter(sourceBinding, cqtLambda);
            }
            else
            {
                // treat the lambda as a selector
                operand = parent._commandTree.CreateProjectExpression(sourceBinding, cqtLambda);
            }
        }
        operand = WrapCollectionOperand(parent, operand, returnType);
        DbGroupExpressionBinding operandBinding = parent._commandTree.CreateGroupExpressionBinding(operand);
        EdmFunction function = FindFunction(parent, call, returnType);

        // create aggregate: a group-by with no keys and exactly one aggregate column
        List<KeyValuePair<string, CqtExpression>> keys = new List<KeyValuePair<string, CqtExpression>>(0); // no key
        List<KeyValuePair<string, DbAggregate>> aggregates = new List<KeyValuePair<string, DbAggregate>>(1);
        const string aggregateName = "Aggregate";
        aggregates.Add(new KeyValuePair<string, DbAggregate>(aggregateName, // name is arbitrary (there is only one in this context)
            parent._commandTree.CreateFunctionAggregate(function, operandBinding.GroupVariable)));
        DbGroupByExpression aggregate = parent._commandTree.CreateGroupByExpression(
            operandBinding, keys, aggregates);
        DbExpressionBinding aggregateBinding = parent._commandTree.CreateExpressionBinding(aggregate);

        // project result
        DbPropertyExpression property = parent._commandTree.CreatePropertyExpression(
            aggregateName, aggregateBinding.Variable);
        DbProjectExpression projection = parent._commandTree.CreateProjectExpression(
            aggregateBinding, parent.AlignTypes(property, call.Type));

        // return a single element to represent the projection
        DbElementExpression element = parent._commandTree.CreateElementExpression(projection);

        // Try to create and log an optimized translation
        TryCreateOptimizedTranslation(parent, lambda, originalOperand, function, element);
        return element;
    }

    // If the aggregate's source is a property of a variable whose input has a recorded optimized
    // (DbGroupByTemplate) translation, add this aggregate to that template's aggregate list under a
    // generated alias ("Aggregate{n}"). Then record originalTranslation ->
    // (template, alias) in parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.
    // Returns silently, recording nothing, when any precondition is not met.
    private void TryCreateOptimizedTranslation(ExpressionConverter parent, LambdaExpression lambda, CqtExpression operand, EdmFunction function, CqtExpression originalTranslation)
    {
        //Aggregates that take predicates as arguments cannot be incorporated into a group by
        if (_takesPredicate && (lambda != null))
        {
            return;
        }

        //Check whether the operand is a property over an output of grouping
        if (operand.ExpressionKind != DbExpressionKind.Property)
        {
            return;
        }
        DbPropertyExpression propertyExpression = (DbPropertyExpression)operand;
        if (propertyExpression.Instance.ExpressionKind != DbExpressionKind.VariableReference)
        {
            return;
        }
        DbVariableReferenceExpression inputVarRef = (DbVariableReferenceExpression)propertyExpression.Instance;

        //If the input corresponding to the var ref has an optimized translation, generate an alternate translation for this as well.
        CqtExpression input;
        if (!parent._variableNameToInputExpression.TryGetValue(inputVarRef.VariableName, out input))
        {
            return;
        }
        DbGroupByTemplate optimizedTranslationOfInput;
        if (!parent._groupByDefaultToOptimizedTranslationMap.TryGetValue(input, out optimizedTranslationOfInput))
        {
            return;
        }

        Debug.Assert(TypeSemantics.IsCollectionType(function.Parameters[0].TypeUsage), "Aggregates should always have collection arguments");
        TypeUsage elementType = TypeHelpers.GetElementTypeUsage(function.Parameters[0].TypeUsage);

        // The aggregate can be added to the list of aggregates.
        CqtExpression aggregateArgument = optimizedTranslationOfInput.Input.GroupVariable;
        if (lambda != null)
        {
            aggregateArgument = parent.TranslateLambda(lambda, aggregateArgument);
        }
        string optimizedTranslationAlias = String.Format(CultureInfo.InvariantCulture, "Aggregate{0}", optimizedTranslationOfInput.Aggregates.Count);
        optimizedTranslationOfInput.Aggregates.Add(new KeyValuePair<string, DbAggregate>(optimizedTranslationAlias,
            parent._commandTree.CreateFunctionAggregate(function,
                WrapNonCollectionOperand(parent, aggregateArgument, elementType))));

        //log the alias
        parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.Add(originalTranslation, new KeyValuePair<DbGroupByTemplate, string>(optimizedTranslationOfInput, optimizedTranslationAlias));
    }

    // If necessary, wraps the collection operand in a projection that casts each element to returnType,
    // so that the appropriate aggregate overload is called.
    protected virtual CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand,
        TypeUsage returnType)
    {
        // check if the operand needs to be wrapped to ensure the correct function overload is called
        if (!TypeUsageEquals(returnType, ((CollectionType)operand.ResultType.EdmType).TypeUsage))
        {
            DbExpressionBinding operandCastBinding = parent._commandTree.CreateExpressionBinding(operand);
            DbProjectExpression operandCastProjection = parent._commandTree.CreateProjectExpression(
                operandCastBinding, parent._commandTree.CreateCastExpression(operandCastBinding.Variable, returnType));
            operand = operandCastProjection;
        }
        return operand;
    }

    // If necessary, casts a single (non-collection) operand to returnType so that the appropriate
    // aggregate overload is called.
    protected virtual CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand,
        TypeUsage returnType)
    {
        if (!TypeUsageEquals(returnType, operand.ResultType))
        {
            operand = parent._commandTree.CreateCastExpression(operand, returnType);
        }
        return operand;
    }

    // Finds the best function overload given the expected return type
    protected virtual EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
        TypeUsage argumentType)
    {
        List<TypeUsage> argTypes = new List<TypeUsage>(1);
        // In general, we use the return type as the parameter type to align LINQ semantics
        // with SQL semantics, and avoid apparent loss of precision for some LINQ aggregate operators.
        // (e.g., AVG(1, 2) = 2.0, AVG((double)1, (double)2)) = 1.5)
        argTypes.Add(argumentType);
        return parent.FindCanonicalFunction(_functionName, argTypes, true /* isGroupAggregateFunction */, call);
    }
}
// MAX aggregate: every Max overload (plain, selector, typed and nullable variants); the lambda, when
// present, is a selector.
private sealed class MaxTranslator : AggregateTranslator
{
    private static readonly SequenceMethod[] s_maxMethods = new SequenceMethod[]
    {
        SequenceMethod.Max, SequenceMethod.MaxSelector,
        SequenceMethod.MaxInt, SequenceMethod.MaxIntSelector,
        SequenceMethod.MaxDecimal, SequenceMethod.MaxDecimalSelector,
        SequenceMethod.MaxDouble, SequenceMethod.MaxDoubleSelector,
        SequenceMethod.MaxLong, SequenceMethod.MaxLongSelector,
        SequenceMethod.MaxSingle, SequenceMethod.MaxSingleSelector,
        SequenceMethod.MaxNullableDecimal, SequenceMethod.MaxNullableDecimalSelector,
        SequenceMethod.MaxNullableDouble, SequenceMethod.MaxNullableDoubleSelector,
        SequenceMethod.MaxNullableInt, SequenceMethod.MaxNullableIntSelector,
        SequenceMethod.MaxNullableLong, SequenceMethod.MaxNullableLongSelector,
        SequenceMethod.MaxNullableSingle, SequenceMethod.MaxNullableSingleSelector
    };

    internal MaxTranslator()
        : base("MAX", false, s_maxMethods)
    {
    }
}
// MIN aggregate: every Min overload (plain, selector, typed and nullable variants); the lambda, when
// present, is a selector. Registration order matches the original list.
private sealed class MinTranslator : AggregateTranslator
{
    private static readonly SequenceMethod[] s_minMethods = new SequenceMethod[]
    {
        SequenceMethod.Min, SequenceMethod.MinSelector,
        SequenceMethod.MinDecimal, SequenceMethod.MinDecimalSelector,
        SequenceMethod.MinDouble, SequenceMethod.MinDoubleSelector,
        SequenceMethod.MinInt, SequenceMethod.MinIntSelector,
        SequenceMethod.MinLong, SequenceMethod.MinLongSelector,
        SequenceMethod.MinNullableDecimal,
        SequenceMethod.MinSingle, SequenceMethod.MinSingleSelector,
        SequenceMethod.MinNullableDecimalSelector,
        SequenceMethod.MinNullableDouble, SequenceMethod.MinNullableDoubleSelector,
        SequenceMethod.MinNullableInt, SequenceMethod.MinNullableIntSelector,
        SequenceMethod.MinNullableLong, SequenceMethod.MinNullableLongSelector,
        SequenceMethod.MinNullableSingle, SequenceMethod.MinNullableSingleSelector
    };

    internal MinTranslator()
        : base("MIN", false, s_minMethods)
    {
    }
}
// AVG aggregate: every typed/nullable Average overload, with and without a selector lambda.
private sealed class AverageTranslator : AggregateTranslator
{
    private static readonly SequenceMethod[] s_averageMethods = new SequenceMethod[]
    {
        SequenceMethod.AverageDecimal, SequenceMethod.AverageDecimalSelector,
        SequenceMethod.AverageDouble, SequenceMethod.AverageDoubleSelector,
        SequenceMethod.AverageInt, SequenceMethod.AverageIntSelector,
        SequenceMethod.AverageLong, SequenceMethod.AverageLongSelector,
        SequenceMethod.AverageSingle, SequenceMethod.AverageSingleSelector,
        SequenceMethod.AverageNullableDecimal, SequenceMethod.AverageNullableDecimalSelector,
        SequenceMethod.AverageNullableDouble, SequenceMethod.AverageNullableDoubleSelector,
        SequenceMethod.AverageNullableInt, SequenceMethod.AverageNullableIntSelector,
        SequenceMethod.AverageNullableLong, SequenceMethod.AverageNullableLongSelector,
        SequenceMethod.AverageNullableSingle, SequenceMethod.AverageNullableSingleSelector
    };

    internal AverageTranslator()
        : base("AVG", false, s_averageMethods)
    {
    }
}
// SUM aggregate: every typed/nullable Sum overload, with and without a selector lambda.
private sealed class SumTranslator : AggregateTranslator
{
    private static readonly SequenceMethod[] s_sumMethods = new SequenceMethod[]
    {
        SequenceMethod.SumDecimal, SequenceMethod.SumDecimalSelector,
        SequenceMethod.SumDouble, SequenceMethod.SumDoubleSelector,
        SequenceMethod.SumInt, SequenceMethod.SumIntSelector,
        SequenceMethod.SumLong, SequenceMethod.SumLongSelector,
        SequenceMethod.SumSingle, SequenceMethod.SumSingleSelector,
        SequenceMethod.SumNullableDecimal, SequenceMethod.SumNullableDecimalSelector,
        SequenceMethod.SumNullableDouble, SequenceMethod.SumNullableDoubleSelector,
        SequenceMethod.SumNullableInt, SequenceMethod.SumNullableIntSelector,
        SequenceMethod.SumNullableLong, SequenceMethod.SumNullableLongSelector,
        SequenceMethod.SumNullableSingle, SequenceMethod.SumNullableSingleSelector
    };

    internal SumTranslator()
        : base("SUM", false, s_sumMethods)
    {
    }
}
// Shared implementation for COUNT-style aggregates. The optional lambda is a predicate, and the value
// being aggregated is always a constant, so only the number of rows contributes.
private abstract class CountTranslatorBase : AggregateTranslator
{
    protected CountTranslatorBase(string functionName, params SequenceMethod[] methods)
        : base(functionName, true, methods)
    {
    }

    // Projects every element of the operand to the constant true; returnType is ignored.
    protected override CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        DbExpressionBinding elementBinding = parent._commandTree.CreateExpressionBinding(operand);
        return parent._commandTree.CreateProjectExpression(
            elementBinding, parent._commandTree.CreateTrueExpression());
    }

    // Ignores operand and returns the constant 1, cast to returnType when its type differs.
    protected override CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        DbExpression one = parent._commandTree.CreateConstantExpression(1);
        return TypeUsageEquals(one.ResultType, returnType)
            ? one
            : parent._commandTree.CreateCastExpression(one, returnType);
    }

    // Unlike other aggregates, the lookup ignores argumentType and always uses Boolean, matching the
    // constant true projected by WrapCollectionOperand.
    protected override EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
        TypeUsage argumentType)
    {
        return base.FindFunction(parent, call, parent._commandTree.CreateTrueExpression().ResultType);
    }
}
// Count() and Count(predicate) -> canonical COUNT.
private sealed class CountTranslator : CountTranslatorBase
{
    internal CountTranslator()
        : base("COUNT", new SequenceMethod[] { SequenceMethod.Count, SequenceMethod.CountPredicate })
    {
    }
}
// LongCount() and LongCount(predicate) -> canonical BIGCOUNT.
private sealed class LongCountTranslator : CountTranslatorBase
{
    internal LongCountTranslator()
        : base("BIGCOUNT", new SequenceMethod[] { SequenceMethod.LongCount, SequenceMethod.LongCountPredicate })
    {
    }
}
// Base for single-input sequence operators. The source set is the receiver (instance form) or the
// first argument (static extension form); it is translated as a set and handed to TranslateUnary.
private abstract class UnarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected UnarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression operand;
        if (null == call.Object)
        {
            // static extension method: source is the first argument
            Debug.Assert(1 <= call.Arguments.Count);
            operand = parent.TranslateSet(call.Arguments[0]);
        }
        else
        {
            // instance method: source is the receiver
            Debug.Assert(0 <= call.Arguments.Count);
            operand = parent.TranslateSet(call.Object);
        }
        return TranslateUnary(parent, operand, call);
    }

    // Translates the operator given its already-translated source set.
    protected abstract CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call);
}
// AsQueryable/AsEnumerable are no-ops over collections and return the operand unchanged. Any other
// operand type (e.g. a String, which is not a sub-query) is rejected.
private sealed class PassthroughTranslator : UnarySequenceMethodTranslator
{
    internal PassthroughTranslator()
        : base(new SequenceMethod[] { SequenceMethod.AsQueryableGeneric, SequenceMethod.AsQueryable, SequenceMethod.AsEnumerable })
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        if (!TypeSemantics.IsCollectionType(operand.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedPassthrough(
                call.Method.Name, operand.ResultType.EdmType.Name));
        }
        return operand;
    }
}
// OfType<T>(): filters the operand to elements of T. T must map to an entity or complex type
// in the perspective; scalars, enumerations, collections etc. are rejected.
private sealed class OfTypeTranslator : UnarySequenceMethodTranslator
{
    internal OfTypeTranslator()
        : base(new SequenceMethod[] { SequenceMethod.OfType })
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        Type targetClrType = call.Method.GetGenericArguments()[0];
        TypeUsage modelType;
        bool isEntityOrComplex = parent.TryGetValueLayerType(targetClrType, out modelType) &&
            (TypeSemantics.IsEntityType(modelType) || TypeSemantics.IsComplexType(modelType));
        if (!isEntityOrComplex)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_InvalidOfTypeResult(DescribeClrType(targetClrType)));
        }
        return parent.OfType(operand, modelType);
    }
}
// Distinct(): delegates to parent.Distinct over the translated source.
private sealed class DistinctTranslator : UnarySequenceMethodTranslator
{
    internal DistinctTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Distinct })
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        CqtExpression distinctRows = parent.Distinct(operand);
        return distinctRows;
    }
}
// Any() without a predicate: translated as NOT(IsEmpty(source)).
private sealed class AnyTranslator : UnarySequenceMethodTranslator
{
    internal AnyTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Any })
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        CqtExpression isEmpty = parent._commandTree.CreateIsEmptyExpression(operand);
        return parent._commandTree.CreateNotExpression(isEmpty);
    }
}
// Base for operators of the form Op(source, lambda). The source is translated, the lambda is bound
// over it, and the pair is handed to TranslateOneLambda.
private abstract class OneLambdaTranslator : SequenceMethodTranslator
{
    // protected, consistent with every other abstract translator base; only derived translators construct this.
    protected OneLambdaTranslator(params SequenceMethod[] methods) : base(methods) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        return Translate(parent, call, out source, out sourceBinding, out lambda);
    }

    // Helper method for translation. It also exposes the translated source, its binding and the
    // translated lambda, so derived translators can post-process the result.
    protected CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, out CqtExpression source, out DbExpressionBinding sourceBinding, out CqtExpression lambda)
    {
        Debug.Assert(2 <= call.Arguments.Count);

        // translate source
        source = parent.TranslateExpression(call.Arguments[0]);

        // translate lambda expression
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        lambda = parent.TranslateLambda(lambdaExpression, source, out sourceBinding);
        return TranslateOneLambda(parent, sourceBinding, lambda);
    }

    // Builds the operator expression from the bound source and the translated lambda body.
    protected abstract CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda);
}
// Any(predicate): translated as an existential (ANY) expression over the bound source.
private sealed class AnyPredicateTranslator : OneLambdaTranslator
{
    internal AnyPredicateTranslator()
        : base(new SequenceMethod[] { SequenceMethod.AnyPredicate })
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression exists = parent._commandTree.CreateAnyExpression(sourceBinding, lambda);
        return exists;
    }
}
// All(predicate): translated as a universal (ALL) expression over the bound source.
private sealed class AllTranslator : OneLambdaTranslator
{
    internal AllTranslator()
        : base(new SequenceMethod[] { SequenceMethod.All })
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression forAll = parent._commandTree.CreateAllExpression(sourceBinding, lambda);
        return forAll;
    }
}
// Where(predicate): filters the bound source by the translated predicate.
private sealed class WhereTranslator : OneLambdaTranslator
{
    internal WhereTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Where })
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression filtered = parent.Filter(sourceBinding, lambda);
        return filtered;
    }
}
// Select(selector): projects the bound source through the translated selector. When
// parent.TryRewrite succeeds (e.g. a select over a GroupBy that can become a DbGroupByExpression),
// the rewritten expression replaces the default, manually-grouped translation.
private sealed class SelectTranslator : OneLambdaTranslator
{
    internal SelectTranslator()
        : base(new SequenceMethod[] { SequenceMethod.Select })
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression lambda;
        CqtExpression defaultTranslation = Translate(parent, call, out source, out sourceBinding, out lambda);

        CqtExpression optimizedTranslation;
        return parent.TryRewrite(source, sourceBinding, lambda, out optimizedTranslation)
            ? optimizedTranslation
            : defaultTranslation;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression projected = parent.Project(sourceBinding, lambda);
        return projected;
    }
}
/// <summary>
/// Shared translation for First() and FirstOrDefault() without a predicate.
/// </summary>
private abstract class FirstTranslatorBase : UnarySequenceMethodTranslator
{
    // true for FirstOrDefault, false for First
    private readonly bool _orDefault;

    protected FirstTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    /// <summary>
    /// Limits the operand to one element. At the query root, the limited set is returned and the
    /// First/FirstOrDefault semantics are applied on the client. When nested, FirstOrDefault becomes
    /// an element expression (with a default-value case), and First is rejected.
    /// </summary>
    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // Apply a Limit(1) expression to restrict the result to a single element
        // (this returns an empty set if the input set is initially empty).
        CqtExpression result = parent.Limit(operand, parent._commandTree.CreateConstantExpression(1));

        // If this First() or FirstOrDefault() operation is the root of the query,
        // then the evaluation is performed in the client over the resulting set,
        // to provide the same semantics as Linq to Objects. Otherwise, an Element
        // expression is applied to retrieve the single element (or null, if empty)
        // from the output set.
        if (!parent.IsQueryRoot(call))
        {
            if (_orDefault)
            {
                result = parent._commandTree.CreateElementExpression(result);
                result = AddDefaultCase(parent, result, call.Type);
            }
            else
            {
                // First is not allowed anywhere other than the root of the query,
                // since First() operations logically executed in the store will not
                // throw an exception if the input set is empty (typically they will
                // simply produce a null result).
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
            }
        }

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span inputSpan = null;
        if (parent.TryGetSpan(operand, out inputSpan))
        {
            parent.AddSpanMapping(result, inputSpan);
        }
        return result;
    }

    /// <summary>
    /// Substitutes the CLR default value of <paramref name="elementType"/> when <paramref name="element"/> is null.
    /// </summary>
    /// <param name="parent">Converter that owns the command tree.</param>
    /// <param name="element">The (possibly null) element expression.</param>
    /// <param name="elementType">CLR type whose default value should replace null.</param>
    /// <returns>
    /// <paramref name="element"/> unchanged when the default value is null; otherwise
    /// CASE WHEN element IS NULL THEN default ELSE element END.
    /// </returns>
    internal static CqtExpression AddDefaultCase(ExpressionConverter parent, CqtExpression element, Type elementType)
    {
        // Retrieve default value.
        object defaultValue = TypeSystem.GetDefaultValue(elementType);
        if (null == defaultValue)
        {
            // Already null, which is the implicit default for DbElementExpression
            return element;
        }

        // Otherwise, use the default value for the type
        List<CqtExpression> whenExpressions = new List<CqtExpression>(1);
        whenExpressions.Add(parent.CreateIsNullExpression(element, elementType));
        List<CqtExpression> thenExpressions = new List<CqtExpression>(1);
        thenExpressions.Add(parent._commandTree.CreateConstantExpression(defaultValue));
        DbCaseExpression caseExpression = parent._commandTree.CreateCaseExpression(
            whenExpressions, thenExpressions, element);
        return caseExpression;
    }
}
/// <summary>First() without a predicate; only supported at the query root.</summary>
private sealed class FirstTranslator : FirstTranslatorBase
{
    internal FirstTranslator()
        : base(false, SequenceMethod.First)
    {
    }
}

/// <summary>FirstOrDefault() without a predicate.</summary>
private sealed class FirstOrDefaultTranslator : FirstTranslatorBase
{
    internal FirstOrDefaultTranslator()
        : base(true, SequenceMethod.FirstOrDefault)
    {
    }
}
/// <summary>
/// Shared translation for First(predicate) and FirstOrDefault(predicate).
/// </summary>
private abstract class FirstPredicateTranslatorBase : OneLambdaTranslator
{
    // true for FirstOrDefault(predicate), false for First(predicate)
    private readonly bool _orDefault;

    protected FirstPredicateTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    /// <summary>
    /// Filters the input by the predicate. At the query root, the filtered set is limited to one element
    /// and the client then applies First()/FirstOrDefault() to it. When nested, FirstOrDefault becomes an
    /// element expression with a default-value case, and First is rejected.
    /// </summary>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Convert the input set and the predicate into a filter expression
        CqtExpression filtered = base.Translate(parent, call);

        if (parent.IsQueryRoot(call))
        {
            // Calling ExpressionConverter.Limit propagates the Span.
            return parent.Limit(filtered, parent._commandTree.CreateConstantExpression(1));
        }

        if (!_orDefault)
        {
            // First (with or without a predicate) is not allowed anywhere other
            // than the root of the query, since First(predicate) operations
            // logically executed in the store will not throw an exception if the
            // input set is empty (typically they will simply produce a null result).
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
        }

        CqtExpression element = FirstTranslatorBase.AddDefaultCase(
            parent, parent._commandTree.CreateElementExpression(filtered), call.Type);

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span filteredSpan;
        if (parent.TryGetSpan(filtered, out filteredSpan))
        {
            parent.AddSpanMapping(element, filteredSpan);
        }
        return element;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression filter = parent.Filter(sourceBinding, lambda);
        return filter;
    }
}
/// <summary>First(predicate); only supported at the query root.</summary>
private sealed class FirstPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstPredicateTranslator()
        : base(false, SequenceMethod.FirstPredicate)
    {
    }
}

/// <summary>FirstOrDefault(predicate).</summary>
private sealed class FirstOrDefaultPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstOrDefaultPredicateTranslator()
        : base(true, SequenceMethod.FirstOrDefaultPredicate)
    {
    }
}
/// <summary>
/// Translates SelectMany (with or without a result selector) as a CROSS APPLY followed by a projection.
/// </summary>
private sealed class SelectManyTranslator : OneLambdaTranslator
{
    internal SelectManyTranslator()
        : base(SequenceMethod.SelectMany, SequenceMethod.SelectManyResultSelector)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Core SelectMany logic (translating the collection selector):
        //   SelectMany(i, Func<T, IEnumerable<TResult>> collectionSelector) =>
        //       i CROSS APPLY collectionSelector(i)
        CqtExpression crossApply = base.Translate(parent, call);

        DbExpressionBinding applyBinding = parent._commandTree.CreateExpressionBinding(crossApply);
        RowType applyRowType = (RowType)applyBinding.Variable.ResultType.EdmType;

        // Property 1 of the apply row is the right-hand side (an element of the inner collection).
        CqtExpression rightSide = parent._commandTree.CreatePropertyExpression(applyRowType.Properties[1], applyBinding.Variable);

        // Without an explicit resultSelector, the right-hand side itself is the result.
        CqtExpression projection = rightSide;

        // Three arguments means a resultSelector was supplied: (source, collectionSelector, resultSelector).
        if (call.Arguments.Count == 3)
        {
            CqtExpression leftSide = parent._commandTree.CreatePropertyExpression(applyRowType.Properties[0], applyBinding.Variable);
            LambdaExpression resultSelector = parent.GetLambdaExpression(call, 2);

            // Bind the selector's two parameters to the left and right columns of the apply, translate, then unbind.
            parent._bindingContext.PushBindingScope(
                new Binding(resultSelector.Parameters[0], leftSide),
                new Binding(resultSelector.Parameters[1], rightSide));
            projection = parent.TranslateSet(resultSelector.Body);
            parent._bindingContext.PopBindingScope();
        }

        return parent._commandTree.CreateProjectExpression(applyBinding, projection);
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        // The elements of the (normalized) inner selector form the right input of the apply.
        DbExpressionBinding innerBinding = parent._commandTree.CreateExpressionBinding(parent.NormalizeSetSource(lambda));
        DbApplyExpression crossApply = parent._commandTree.CreateCrossApplyExpression(sourceBinding, innerBinding);
        return crossApply;
    }
}
/// <summary>
/// Translates Cast by projecting every element of the source through a cast to the target element type.
/// </summary>
private sealed class CastMethodTranslator : SequenceMethodTranslator
{
    internal CastMethodTranslator()
        : base(SequenceMethod.Cast)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression sourceArgument = call.Arguments[0];
        CqtExpression translatedSource = parent.TranslateSet(sourceArgument);

        // Element types of the result sequence (target) and of the input sequence (source).
        Type targetElementType = TypeSystem.GetElementType(call.Type);
        Type sourceElementType = TypeSystem.GetElementType(sourceArgument.Type);

        DbExpressionBinding elementBinding = parent._commandTree.CreateExpressionBinding(translatedSource);
        CqtExpression castElement = parent.CreateCastExpression(elementBinding.Variable, targetElementType, sourceElementType);
        return parent._commandTree.CreateProjectExpression(elementBinding, castElement);
    }
}
/// <summary>
/// Translates the GroupBy overloads (optionally with element and/or result selectors).
/// </summary>
private sealed class GroupByTranslator : SequenceMethodTranslator
{
    internal GroupByTranslator()
        : base(SequenceMethod.GroupBy, SequenceMethod.GroupByElementSelector, SequenceMethod.GroupByElementSelectorResultSelector,
        SequenceMethod.GroupByResultSelector)
    {
    }

    // The default translation of GroupBy is:
    //     SELECT d as Key, (SELECT VALUE g FROM source WHERE source.Key = d) as Group
    //     FROM (SELECT DISTINCT source.Key)
    //
    // The optimized translation is simply creating a Cqt GroupByExpression. When there is no element
    // selector, a DbGroupByTemplate is registered against the default translation so that a later
    // rewrite (TryRewrite) can substitute the optimized form.
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
    {
        // translate source
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);

        // translate key selector
        LambdaExpression keySelectorLinq = parent.GetLambdaExpression(call, 1);
        DbExpressionBinding sourceBinding;
        CqtExpression keySelector = parent.TranslateLambda(keySelectorLinq, source, out sourceBinding);

        // translate the key selector again in a different binding context (for the nested select)
        DbExpressionBinding nestedSourceBinding;
        CqtExpression nestedSelector = parent.TranslateLambda(keySelectorLinq, source, out nestedSourceBinding);

        // create distinct expression
        if (!TypeSemantics.IsEqualComparable(keySelector.ResultType))
        {
            // to avoid confusing error message about the "distinct" type, pre-emptively raise an exception
            // about the group by key selector
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }
        CqtExpression distinct = parent.Distinct(
            parent._commandTree.CreateProjectExpression(sourceBinding, keySelector));
        DbExpressionBinding distinctBinding = parent._commandTree.CreateExpressionBinding(distinct);

        // create group projection term.
        // Both operands are key values, so their CLR type is the key selector's *body* type
        // (keySelectorLinq.Type would be the delegate type); this matches GroupJoin's usage.
        Type keyClrType = keySelectorLinq.Body.Type;
        DbFilterExpression groupKeyFilter = parent.Filter(
            nestedSourceBinding, parent.CreateEqualsExpression(nestedSelector, distinctBinding.Variable, EqualsPattern.PositiveNullEquality, keyClrType, keyClrType));

        // interpret element selector if needed
        CqtExpression selection = groupKeyFilter;
        bool hasElementSelector = sequenceMethod == SequenceMethod.GroupByElementSelector ||
            sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector;
        if (hasElementSelector)
        {
            LambdaExpression elementSelectorLinq = parent.GetLambdaExpression(call, 2);
            DbExpressionBinding elementSelectorSourceBinding;
            CqtExpression elementSelector = parent.TranslateLambda(elementSelectorLinq, selection, out elementSelectorSourceBinding);
            selection = parent._commandTree.CreateProjectExpression(elementSelectorSourceBinding,
                elementSelector);
        }

        // create top level projection
        List<CqtExpression> projectionTerms = new List<CqtExpression>(2);
        projectionTerms.Add(distinctBinding.Variable);
        projectionTerms.Add(selection);

        // build projection type with initializer information
        List<EdmProperty> properties = new List<EdmProperty>(2);
        properties.Add(new EdmProperty(KeyColumnName, projectionTerms[0].ResultType));
        properties.Add(new EdmProperty(GroupColumnName, projectionTerms[1].ResultType));
        InitializerMetadata initializerMetadata = InitializerMetadata.CreateGroupingInitializer(
            parent.EdmItemCollection, TypeSystem.GetElementType(call.Type));
        RowType rowType = new RowType(properties, initializerMetadata);
        TypeUsage rowTypeUsage = TypeUsage.Create(rowType);

        CqtExpression topLevelProject = parent._commandTree.CreateProjectExpression(distinctBinding,
            parent._commandTree.CreateNewInstanceExpression(rowTypeUsage, projectionTerms));

        if (!hasElementSelector)
        {
            //Create optimized translation for the GroupBy - simple GroupBy template
            DbGroupExpressionBinding groupByBinding;
            CqtExpression newKeySelector = parent.TranslateLambda(keySelectorLinq, source, out groupByBinding);
            DbGroupByTemplate groupByTemplate = new DbGroupByTemplate(groupByBinding);
            groupByTemplate.GroupKeys.Add(new KeyValuePair<string, CqtExpression>(KeyColumnName, newKeySelector));
            parent._groupByDefaultToOptimizedTranslationMap.Add(topLevelProject, groupByTemplate);
        }

        // GroupBy may include a result selector; handle it
        return ProcessResultSelector(parent, call, sequenceMethod, topLevelProject, topLevelProject);
    }

    /// <summary>
    /// Applies the GroupBy result selector, if the overload has one, over the (Key, Group) rows.
    /// </summary>
    /// <returns>
    /// <paramref name="result"/> unchanged when there is no result selector; otherwise the projection of
    /// the selector (or its rewritten form, when <c>TryRewrite</c> succeeds).
    /// </returns>
    private static DbExpression ProcessResultSelector(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod, CqtExpression topLevelProject, DbExpression result)
    {
        // interpret result selector if needed
        LambdaExpression resultSelectorLinqExpression = null;
        if (sequenceMethod == SequenceMethod.GroupByResultSelector)
        {
            resultSelectorLinqExpression = parent.GetLambdaExpression(call, 2);
        }
        else if (sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector)
        {
            resultSelectorLinqExpression = parent.GetLambdaExpression(call, 3);
        }

        if (null != resultSelectorLinqExpression)
        {
            // selector maps (Key, Group) -> Result
            // push bindings for key and group
            DbExpressionBinding topLevelProjectBinding = parent._commandTree.CreateExpressionBinding(topLevelProject);
            parent._variableNameToInputExpression.Add(topLevelProjectBinding.VariableName, topLevelProject);
            DbPropertyExpression keyExpression = parent._commandTree.CreatePropertyExpression(
                KeyColumnName, topLevelProjectBinding.Variable);
            DbPropertyExpression groupExpression = parent._commandTree.CreatePropertyExpression(
                GroupColumnName, topLevelProjectBinding.Variable);
            parent._bindingContext.PushBindingScope(
                new Binding(resultSelectorLinqExpression.Parameters[0], keyExpression),
                new Binding(resultSelectorLinqExpression.Parameters[1], groupExpression));

            // translate selector
            CqtExpression resultSelector = parent.TranslateExpression(
                resultSelectorLinqExpression.Body);
            result = parent._commandTree.CreateProjectExpression(topLevelProjectBinding, resultSelector);

            // see if the selector can be optimized
            CqtExpression rewrittenExpression;
            if (parent.TryRewrite(topLevelProject, topLevelProjectBinding, resultSelector, out rewrittenExpression))
            {
                result = rewrittenExpression;
            }
            parent._bindingContext.PopBindingScope();
        }
        return result;
    }

    // GroupBy is always dispatched through the three-argument overload above.
    internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Fail("unreachable code");
        return null;
    }
}
/// <summary>
/// Translates GroupJoin into a projection that pairs each outer element with its matching inner elements.
/// </summary>
private sealed class GroupJoinTranslator : SequenceMethodTranslator
{
    internal GroupJoinTranslator()
        : base(SequenceMethod.GroupJoin)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // o.GroupJoin(i, ok => outerKeySelector, ik => innerKeySelector, (o, i) => projection)
        //      -->
        // SELECT projection(o, i)
        // FROM (
        //      SELECT o, (SELECT i FROM i WHERE o.outerKeySelector = i.innerKeySelector) as i
        //      FROM o)

        // translate inputs
        CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
        CqtExpression inner = parent.TranslateSet(call.Arguments[1]);

        // translate key selectors
        DbExpressionBinding outerBinding;
        DbExpressionBinding innerBinding;
        LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
        LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);
        CqtExpression outerSelector = parent.TranslateLambda(
            outerLambda, outer, out outerBinding);
        CqtExpression innerSelector = parent.TranslateLambda(
            innerLambda, inner, out innerBinding);

        // create innermost SELECT i FROM i WHERE ...
        if (!TypeSemantics.IsEqualComparable(outerSelector.ResultType) ||
            !TypeSemantics.IsEqualComparable(innerSelector.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }
        CqtExpression nestedCollection = parent.Filter(innerBinding,
            parent.CreateEqualsExpression(outerSelector, innerSelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type));

        // create "join" SELECT o, (nestedCollection)
        const string outerColumn = "o";
        const string innerColumn = "i";
        List<KeyValuePair<string, CqtExpression>> recordColumns = new List<KeyValuePair<string, CqtExpression>>(2);
        recordColumns.Add(new KeyValuePair<string, CqtExpression>(outerColumn, outerBinding.Variable));
        recordColumns.Add(new KeyValuePair<string, CqtExpression>(innerColumn, nestedCollection));
        CqtExpression joinProjection = parent._commandTree.CreateNewRowExpression(recordColumns);
        CqtExpression joinProject = parent._commandTree.CreateProjectExpression(outerBinding, joinProjection);
        DbExpressionBinding joinProjectBinding = parent._commandTree.CreateExpressionBinding(joinProject);

        // create property expressions for the outer and inner terms to bind to the parameters to the
        // group join selector
        CqtExpression outerProperty = parent._commandTree.CreatePropertyExpression(outerColumn,
            joinProjectBinding.Variable);
        CqtExpression innerProperty = parent._commandTree.CreatePropertyExpression(innerColumn,
            joinProjectBinding.Variable);

        // push the inner and the outer terms into the binding scope
        LambdaExpression linqSelector = parent.GetLambdaExpression(call, 4);
        parent._bindingContext.PushBindingScope(
            new Binding(linqSelector.Parameters[0], outerProperty),
            new Binding(linqSelector.Parameters[1], innerProperty));

        // translate the selector
        CqtExpression selectorProject = parent.TranslateExpression(linqSelector.Body);

        // pop the binding scope
        parent._bindingContext.PopBindingScope();

        // create the selector projection
        CqtExpression selector = parent._commandTree.CreateProjectExpression(joinProjectBinding, selectorProject);
        return selector;
    }
}
/// <summary>
/// Shared translation for OrderBy / OrderByDescending: a sort on the single translated key.
/// </summary>
private abstract class OrderByTranslatorBase : OneLambdaTranslator
{
    // true for OrderBy, false for OrderByDescending
    private readonly bool _ascending;

    protected OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        List<DbSortClause> keys = new List<DbSortClause>(1);
        DbSortClause sortSpec = parent._commandTree.CreateSortClause(lambda, _ascending);
        keys.Add(sortSpec);
        DbSortExpression sort = parent.Sort(sourceBinding, keys);
        return sort;
    }
}
/// <summary>OrderBy: ascending sort on the key selector.</summary>
private sealed class OrderByTranslator : OrderByTranslatorBase
{
    internal OrderByTranslator()
        : base(true, SequenceMethod.OrderBy)
    {
    }
}

/// <summary>OrderByDescending: descending sort on the key selector.</summary>
private sealed class OrderByDescendingTranslator : OrderByTranslatorBase
{
    internal OrderByDescendingTranslator()
        : base(false, SequenceMethod.OrderByDescending)
    {
    }
}
// Note: because we need to "push-down" the expression binding for ThenBy, this class
// does not inherit from OneLambdaTranslator, although it is similar.
private abstract class ThenByTranslatorBase : SequenceMethodTranslator
{
    // true for ThenBy, false for ThenByDescending
    private readonly bool _ascending;

    protected ThenByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Appends a sort key to the sort expression that the source translated to.
    /// Throws (via EntityUtil.InvalidOperation) if the source did not translate to a sort.
    /// </summary>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(2 == call.Arguments.Count);
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);
        if (DbExpressionKind.Sort != source.ExpressionKind)
        {
            throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.ELinq_ThenByDoesNotFollowOrderBy);
        }
        DbSortExpression sortExpression = (DbSortExpression)source;

        // retrieve information about existing sort
        DbExpressionBinding binding = sortExpression.Input;

        // get information on new sort term
        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        ParameterExpression parameter = lambdaExpression.Parameters[0];

        // push-down the binding scope information and translate the new sort key
        parent._bindingContext.PushBindingScope(new Binding(parameter, binding.Variable));
        CqtExpression lambda = parent.TranslateExpression(lambdaExpression.Body);
        parent._bindingContext.PopBindingScope();

        // create a new sort expression: existing keys first, then the new one
        List<DbSortClause> keys = new List<DbSortClause>(sortExpression.SortOrder);
        keys.Add(new DbSortClause(lambda, _ascending, null));
        sortExpression = parent.Sort(binding, keys);

        return sortExpression;
    }
}
/// <summary>ThenBy: appends an ascending sort key.</summary>
private sealed class ThenByTranslator : ThenByTranslatorBase
{
    internal ThenByTranslator()
        : base(true, SequenceMethod.ThenBy)
    {
    }
}

/// <summary>ThenByDescending: appends a descending sort key.</summary>
private sealed class ThenByDescendingTranslator : ThenByTranslatorBase
{
    internal ThenByDescendingTranslator()
        : base(false, SequenceMethod.ThenByDescending)
    {
    }
}
#endregion
}
}
}
// File provided for Reference Use Only by Microsoft Corporation (c) 2007.
//----------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
// @owner [....], [....]
//---------------------------------------------------------------------
using System.Data.Common.CommandTrees;
using System.Collections.Generic;
using CqtExpression = System.Data.Common.CommandTrees.DbExpression;
using LinqExpression = System.Linq.Expressions.Expression;
using System.Diagnostics;
using System.Data.Metadata.Edm;
using System.Linq.Expressions;
using System.Reflection;
using System.Linq;
using System.Data.Entity;
using System.Data.Common;
using System.Globalization;
namespace System.Data.Objects.ELinq
{
internal sealed partial class ExpressionConverter
{
/// <summary>
/// Translates System.Linq.Expressions.MethodCallExpression to System.Data.Common.CommandTrees.DbExpression
/// </summary>
private sealed class MethodCallTranslator : TypedTranslator
{
internal MethodCallTranslator()
: base(ExpressionType.Call) { }
/// <summary>
/// Dispatches a method call to the first matching translator, in order: sequence (query operator) methods,
/// individually registered methods, public ObjectQuery&lt;T&gt; builder methods, and finally the default
/// translator, which reports the method as unsupported.
/// </summary>
protected override CqtExpression TypedTranslate(ExpressionConverter parent, MethodCallExpression linq)
{
    MethodInfo method = linq.Method;

    // 1. Known sequence methods.
    SequenceMethod sequenceMethod;
    SequenceMethodTranslator sequenceTranslator;
    if (ReflectionUtil.TryIdentifySequenceMethod(method, out sequenceMethod) &&
        s_sequenceTranslators.TryGetValue(sequenceMethod, out sequenceTranslator))
    {
        return sequenceTranslator.Translate(parent, linq, sequenceMethod);
    }

    // 2. Individually registered methods (string/Math canonical functions, Visual Basic helpers, ...).
    CallTranslator callTranslator;
    if (TryGetCallTranslator(method, out callTranslator))
    {
        return callTranslator.Translate(parent, linq);
    }

    // 3. Public builder methods declared on ObjectQuery<T>.
    Type declaringType = method.DeclaringType;
    bool isObjectQueryMethod = method.IsPublic &&
        declaringType != null &&
        declaringType.IsGenericType &&
        declaringType.GetGenericTypeDefinition() == typeof(ObjectQuery<>);
    ObjectQueryCallTranslator builderTranslator;
    if (isObjectQueryMethod &&
        s_objectQueryTranslators.TryGetValue(method.Name, out builderTranslator))
    {
        return builderTranslator.Translate(parent, linq);
    }

    // 4. Nothing matched: fall back on the default translator.
    return s_defaultTranslator.Translate(parent, linq);
}
#region Static members and initializers
// Full name of the Visual Basic runtime type that hosts string helpers such as Mid.
private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";
// initialize fall-back translator
private static readonly CallTranslator s_defaultTranslator = new DefaultTranslator();
private static readonly Dictionary<MethodInfo, CallTranslator> s_methodTranslators = InitializeMethodTranslators();
// NOTE: static initializers run in textual order. s_sequenceTranslators must stay before
// s_objectQueryTranslators, because ObjectQueryBuilderCallTranslator's constructor reads it.
private static readonly Dictionary<SequenceMethod, SequenceMethodTranslator> s_sequenceTranslators = InitializeSequenceMethodTranslators();
private static readonly Dictionary<string, ObjectQueryCallTranslator> s_objectQueryTranslators = InitializeObjectQueryTranslators();
// Guards one-time registration of the Visual Basic translators (see TryGetCallTranslator).
private static bool s_vbMethodsInitialized;
private static readonly object s_vbInitializerLock = new object();
/// <summary>
/// Builds the lookup from individual methods (e.g., Int32.op_Equality) to their translators.
/// Throws if two translators claim the same method.
/// </summary>
private static Dictionary<MethodInfo, CallTranslator> InitializeMethodTranslators()
{
    Dictionary<MethodInfo, CallTranslator> methodTranslators = new Dictionary<MethodInfo, CallTranslator>();
    foreach (CallTranslator translator in GetCallTranslators())
    {
        foreach (MethodInfo method in translator.Methods)
        {
            methodTranslators.Add(method, translator);
        }
    }
    return methodTranslators;
}
/// <summary>
/// Builds the lookup from sequence methods (e.g., Sequence.Select) to their translators.
/// Throws if two translators claim the same sequence method.
/// </summary>
private static Dictionary<SequenceMethod, SequenceMethodTranslator> InitializeSequenceMethodTranslators()
{
    Dictionary<SequenceMethod, SequenceMethodTranslator> sequenceTranslators = new Dictionary<SequenceMethod, SequenceMethodTranslator>();
    foreach (SequenceMethodTranslator translator in GetSequenceMethodTranslators())
    {
        foreach (SequenceMethod method in translator.Methods)
        {
            sequenceTranslators.Add(method, translator);
        }
    }
    return sequenceTranslators;
}
/// <summary>
/// Builds the lookup from ObjectQuery method names (e.g. ObjectQuery.OfType(), ObjectQuery.Include(string))
/// to their translators, compared ordinally. A later translator with the same name replaces an earlier one.
/// </summary>
private static Dictionary<string, ObjectQueryCallTranslator> InitializeObjectQueryTranslators()
{
    Dictionary<string, ObjectQueryCallTranslator> objectQueryCallTranslators = new Dictionary<string, ObjectQueryCallTranslator>(StringComparer.Ordinal);
    foreach (ObjectQueryCallTranslator translator in GetObjectQueryCallTranslators())
    {
        objectQueryCallTranslators[translator.MethodName] = translator;
    }
    return objectQueryCallTranslators;
}
/// <summary>
/// Tries to get a translator for the given method info.
/// If the method is declared in the Visual Basic runtime assembly, the Visual Basic
/// translators are initialized on first use.
/// </summary>
/// <param name="methodInfo">The called method.</param>
/// <param name="callTranslator">The translator registered for <paramref name="methodInfo"/>, or null.</param>
/// <returns>True if a translator was found; otherwise false.</returns>
private static bool TryGetCallTranslator(MethodInfo methodInfo, out CallTranslator callTranslator)
{
    // s_methodTranslators is no longer mutated after static initialization (the Visual Basic
    // translators are kept separately below), so it can be read without taking the lock.
    if (s_methodTranslators.TryGetValue(methodInfo, out callTranslator))
    {
        return true;
    }

    // Global (module-level) methods have no declaring type. They cannot be Visual Basic helpers,
    // and TypedTranslate falls back on the default translator for them.
    Type declaringType = methodInfo.DeclaringType;

    // check if this is the visual basic assembly
    if (null != declaringType &&
        s_visualBasicAssemblyFullName == declaringType.Assembly.FullName)
    {
        lock (s_vbInitializerLock)
        {
            if (!s_vbMethodsInitialized)
            {
                InitializeVBMethods(declaringType.Assembly);
                s_vbMethodsInitialized = true;
            }
            // The VB dictionary is only ever read or written while holding the lock.
            return s_vbMethodTranslators.TryGetValue(methodInfo, out callTranslator);
        }
    }

    callTranslator = null;
    return false;
}

// Translators for methods of the Visual Basic runtime assembly. Built lazily (that assembly may never
// be used) and only read or written while holding s_vbInitializerLock.
private static Dictionary<MethodInfo, CallTranslator> s_vbMethodTranslators;

/// <summary>
/// Builds the translators for the methods of the given Visual Basic runtime assembly.
/// Must be called while holding s_vbInitializerLock.
/// </summary>
/// <param name="vbAssembly">The Visual Basic runtime assembly.</param>
private static void InitializeVBMethods(Assembly vbAssembly)
{
    Debug.Assert(!s_vbMethodsInitialized);
    Dictionary<MethodInfo, CallTranslator> vbTranslators = new Dictionary<MethodInfo, CallTranslator>();
    foreach (CallTranslator translator in GetVisualBasicCallTranslators(vbAssembly))
    {
        foreach (MethodInfo method in translator.Methods)
        {
            vbTranslators.Add(method, translator);
        }
    }
    // Publish only once fully built, so a failure part-way leaves no partially populated state.
    s_vbMethodTranslators = vbTranslators;
}
/// <summary>
/// Lazily yields the translators for methods defined in the given Visual Basic runtime assembly.
/// </summary>
private static IEnumerable<CallTranslator> GetVisualBasicCallTranslators(Assembly vbAssembly)
{
    yield return new VBCanonicalFunctionDefaultTranslator(vbAssembly);
    yield return new VBCanonicalFunctionRenameTranslator(vbAssembly);
    yield return new VBDatePartTranslator(vbAssembly);
}
/// <summary>
/// Lazily yields the translators for individually supported (non-Visual Basic) methods.
/// </summary>
private static IEnumerable<CallTranslator> GetCallTranslators()
{
    yield return new CanonicalFunctionDefaultTranslator();
    yield return new ContainsTranslator();
    yield return new StartsWithTranslator();
    yield return new EndsWithTranslator();
    yield return new IndexOfTranslator();
    yield return new SubstringTranslator();
    yield return new RemoveTranslator();
    yield return new InsertTranslator();
    yield return new IsNullOrEmptyTranslator();
    yield return new StringConcatTranslator();
    yield return new TrimStartTranslator();
    yield return new TrimEndTranslator();
}
/// <summary>
/// Lazily yields the translators for supported sequence methods.
/// </summary>
private static IEnumerable<SequenceMethodTranslator> GetSequenceMethodTranslators()
{
    yield return new ConcatTranslator();
    yield return new UnionTranslator();
    yield return new IntersectTranslator();
    yield return new ExceptTranslator();
    yield return new DistinctTranslator();
    yield return new WhereTranslator();
    yield return new SelectTranslator();
    yield return new OrderByTranslator();
    yield return new OrderByDescendingTranslator();
    yield return new ThenByTranslator();
    yield return new ThenByDescendingTranslator();
    yield return new SelectManyTranslator();
    yield return new AnyTranslator();
    yield return new AnyPredicateTranslator();
    yield return new AllTranslator();
    yield return new JoinTranslator();
    yield return new GroupByTranslator();
    yield return new MaxTranslator();
    yield return new MinTranslator();
    yield return new AverageTranslator();
    yield return new SumTranslator();
    yield return new CountTranslator();
    yield return new LongCountTranslator();
    yield return new CastMethodTranslator();
    yield return new GroupJoinTranslator();
    yield return new OfTypeTranslator();
    yield return new SingleTranslatorNotSupported();
    yield return new PassthroughTranslator();
    yield return new FirstTranslator();
    yield return new FirstPredicateTranslator();
    yield return new FirstOrDefaultTranslator();
    yield return new FirstOrDefaultPredicateTranslator();
    yield return new TakeTranslator();
    yield return new SkipTranslator();
}
/// <summary>
/// Lazily yields the translators for supported ObjectQuery&lt;T&gt; builder methods.
/// The builder translators look up s_sequenceTranslators in their constructors.
/// </summary>
private static IEnumerable<ObjectQueryCallTranslator> GetObjectQueryCallTranslators()
{
    yield return new ObjectQueryBuilderDistinctTranslator();
    yield return new ObjectQueryBuilderExceptTranslator();
    yield return new ObjectQueryBuilderFirstTranslator();
    yield return new ObjectQueryIncludeTranslator();
    yield return new ObjectQueryBuilderIntersectTranslator();
    yield return new ObjectQueryBuilderOfTypeTranslator();
    yield return new ObjectQueryBuilderUnionTranslator();
}
#endregion
#region Method translators
/// <summary>
/// Base class for translators of individual (non-sequence) methods.
/// </summary>
private abstract class CallTranslator
{
    // Methods handled by this translator; used as lookup keys when registering translators.
    private readonly IEnumerable<MethodInfo> _methods;

    protected CallTranslator(params MethodInfo[] methods) { _methods = methods; }

    protected CallTranslator(IEnumerable<MethodInfo> methods) { _methods = methods; }

    internal IEnumerable<MethodInfo> Methods { get { return _methods; } }

    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    public override string ToString()
    {
        return GetType().Name;
    }
}
/// <summary>
/// Base class for translators of ObjectQuery&lt;T&gt; methods, which are matched by method name
/// rather than by MethodInfo (so no methods are registered with the base class).
/// </summary>
private abstract class ObjectQueryCallTranslator : CallTranslator
{
    // Name of the ObjectQuery<T> method handled; used as the lookup key.
    private readonly string _methodName;

    protected ObjectQueryCallTranslator(string methodName)
        : base()
    {
        _methodName = methodName;
    }

    internal string MethodName
    {
        get
        {
            return _methodName;
        }
    }
}
/// <summary>
/// Translates an ObjectQuery&lt;T&gt; builder method by delegating to the translator of its
/// equivalent sequence method.
/// </summary>
private abstract class ObjectQueryBuilderCallTranslator : ObjectQueryCallTranslator
{
    // Translator for the equivalent sequence method.
    private readonly SequenceMethodTranslator _translator;

    protected ObjectQueryBuilderCallTranslator(string methodName, SequenceMethod sequenceEquivalent)
        : base(methodName)
    {
        // Requires s_sequenceTranslators to have been initialized already.
        bool found = s_sequenceTranslators.TryGetValue(sequenceEquivalent, out _translator);
        Debug.Assert(found, "Translator not found for " + sequenceEquivalent.ToString());
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        SequenceMethodTranslator equivalent = _translator;
        return equivalent.Translate(parent, call);
    }
}
/// <summary>ObjectQuery&lt;T&gt;.Union, translated like the Union sequence method.</summary>
private sealed class ObjectQueryBuilderUnionTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderUnionTranslator() : base("Union", SequenceMethod.Union) { }
}

/// <summary>ObjectQuery&lt;T&gt;.Intersect, translated like the Intersect sequence method.</summary>
private sealed class ObjectQueryBuilderIntersectTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderIntersectTranslator() : base("Intersect", SequenceMethod.Intersect) { }
}

/// <summary>ObjectQuery&lt;T&gt;.Except, translated like the Except sequence method.</summary>
private sealed class ObjectQueryBuilderExceptTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderExceptTranslator() : base("Except", SequenceMethod.Except) { }
}

/// <summary>ObjectQuery&lt;T&gt;.Distinct, translated like the Distinct sequence method.</summary>
private sealed class ObjectQueryBuilderDistinctTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderDistinctTranslator() : base("Distinct", SequenceMethod.Distinct) { }
}

/// <summary>ObjectQuery&lt;T&gt;.OfType, translated like the OfType sequence method.</summary>
private sealed class ObjectQueryBuilderOfTypeTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderOfTypeTranslator() : base("OfType", SequenceMethod.OfType) { }
}

/// <summary>ObjectQuery&lt;T&gt; method named "First", translated like the First sequence method.</summary>
private sealed class ObjectQueryBuilderFirstTranslator : ObjectQueryBuilderCallTranslator
{
    internal ObjectQueryBuilderFirstTranslator() : base("First", SequenceMethod.First) { }
}
/// <summary>
/// Translates ObjectQuery&lt;T&gt;.Include(string) by adding the include path to the span of the query.
/// </summary>
private sealed class ObjectQueryIncludeTranslator : ObjectQueryCallTranslator
{
    internal ObjectQueryIncludeTranslator() : base("Include") { }

    internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Object != null && call.Arguments.Count == 1 && call.Arguments[0] != null && call.Arguments[0].Type.Equals(typeof(string)), "Invalid Include arguments?");

        CqtExpression queryExpression = parent.TranslateExpression(call.Object);

        // Start from the query's existing span, if it has one.
        Span existingSpan;
        if (!parent.TryGetSpan(queryExpression, out existingSpan))
        {
            existingSpan = null;
        }

        CqtExpression pathArgument = parent.TranslateExpression(call.Arguments[0]);
        if (pathArgument.ExpressionKind != DbExpressionKind.Constant)
        {
            // The 'Include' method implementation on ELinqQueryState creates
            // a method call expression with a string constant argument taking
            // the value of the string argument passed to ObjectQuery.Include,
            // and so this is the only supported pattern here.
            throw EntityUtil.NotSupported(Entity.Strings.ELinq_UnsupportedInclude);
        }

        string includePath = (string)((DbConstantExpression)pathArgument).Value;
        return parent.AddSpanMapping(queryExpression, Span.IncludeIn(existingSpan, includePath));
    }
}
/// <summary>
/// Fall-back translator for methods with no registered translator. It always throws NotSupported,
/// naming a suggested alternative method when one is known.
/// </summary>
private sealed class DefaultTranslator : CallTranslator
{
    internal DefaultTranslator() : base() { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        MethodInfo suggestedMethodInfo;
        if (TryGetAlternativeMethod(call.Method, out suggestedMethodInfo))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethodSuggestedAlternative(call.Method, suggestedMethodInfo));
        }
        //The default error message
        throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedMethod(call.Method));
    }

    #region Static Members
    // Maps an unsupported method to a supported alternative to suggest in the error message.
    // Mutated by InitializeVBMethods, so it is only accessed while holding s_vbInitializerLock.
    private static readonly Dictionary<MethodInfo, MethodInfo> s_alternativeMethods = InitializeAlternateMethodInfos();
    private static bool s_vbMethodsInitialized;
    private static readonly object s_vbInitializerLock = new object();

    /// <summary>
    /// Tries to find an alternative method to suggest instead of the given unsupported one.
    /// If the method is declared in the Visual Basic runtime assembly, the Visual Basic
    /// alternatives are registered on first use.
    /// </summary>
    /// <param name="originalMethodInfo">The unsupported method.</param>
    /// <param name="suggestedMethodInfo">The suggested alternative, or null if there is none.</param>
    /// <returns>True if an alternative was found; otherwise false.</returns>
    private static bool TryGetAlternativeMethod(MethodInfo originalMethodInfo, out MethodInfo suggestedMethodInfo)
    {
        // Global (module-level) methods have no declaring type and cannot be Visual Basic helpers.
        Type declaringType = originalMethodInfo.DeclaringType;
        bool isVisualBasicMethod = null != declaringType &&
            s_visualBasicAssemblyFullName == declaringType.Assembly.FullName;

        // Every access to s_alternativeMethods takes the lock, because InitializeVBMethods may be
        // adding to it concurrently. This path only runs right before an exception is thrown,
        // so the locking cost does not matter.
        lock (s_vbInitializerLock)
        {
            if (isVisualBasicMethod && !s_vbMethodsInitialized)
            {
                InitializeVBMethods(declaringType.Assembly);
                s_vbMethodsInitialized = true;
            }
            return s_alternativeMethods.TryGetValue(originalMethodInfo, out suggestedMethodInfo);
        }
    }

    /// <summary>
    /// Initializes the dictionary of alternative methods.
    /// Currently, it simply initializes an empty dictionary.
    /// </summary>
    /// <returns>An empty dictionary.</returns>
    private static Dictionary<MethodInfo, MethodInfo> InitializeAlternateMethodInfos()
    {
        return new Dictionary<MethodInfo, MethodInfo>(1);
    }

    /// <summary>
    /// Populates the dictionary of alternative methods with the VB methods.
    /// Must be called while holding s_vbInitializerLock.
    /// </summary>
    /// <param name="vbAssembly">The Visual Basic runtime assembly.</param>
    private static void InitializeVBMethods(Assembly vbAssembly)
    {
        Debug.Assert(!s_vbMethodsInitialized);

        //Handle { Mid(arg1, arg2), Mid(arg1, arg2, arg3) }: suggest the three-argument form for the two-argument one
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        s_alternativeMethods.Add(
            stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int) }, null),
            stringsType.GetMethod("Mid", BindingFlags.Public | BindingFlags.Static, null, new Type[] { typeof(string), typeof(int), typeof(int) }, null));
    }
    #endregion
}
/// <summary>
/// Translates BCL methods that map one-to-one onto the canonical function with the same name.
/// For instance methods the receiver is prepended as the first canonical-function argument.
/// </summary>
private sealed class CanonicalFunctionDefaultTranslator : CallTranslator
{
    internal CanonicalFunctionDefaultTranslator()
        : base(GetMethods())
    {
    }

    /// <summary>Enumerates the Math, Decimal and String overloads handled by this translator.</summary>
    private static IEnumerable<MethodInfo> GetMethods()
    {
        const BindingFlags publicStatic = BindingFlags.Public | BindingFlags.Static;
        const BindingFlags publicInstance = BindingFlags.Public | BindingFlags.Instance;

        // Math functions: each name is resolved for decimal first, then for double.
        foreach (string mathName in new string[] { "Ceiling", "Floor", "Round" })
        {
            yield return typeof(Math).GetMethod(mathName, publicStatic, null, new Type[] { typeof(decimal) }, null);
            yield return typeof(Math).GetMethod(mathName, publicStatic, null, new Type[] { typeof(double) }, null);
        }

        // Decimal functions
        foreach (string decimalName in new string[] { "Floor", "Ceiling", "Round" })
        {
            yield return typeof(decimal).GetMethod(decimalName, publicStatic, null, new Type[] { typeof(decimal) }, null);
        }

        // String functions
        Type stringType = typeof(string);
        yield return stringType.GetMethod("Replace", publicInstance, null, new Type[] { typeof(string), typeof(string) }, null);
        foreach (string parameterlessName in new string[] { "ToLower", "ToUpper", "Trim" })
        {
            yield return stringType.GetMethod(parameterlessName, publicInstance, null, Type.EmptyTypes, null);
        }
    }

    /// <summary>
    /// MethodName(a1, .., an)      -> MethodName(a1, .., an)
    /// this.MethodName(a1, .., an) -> MethodName(this, a1, .., an)
    /// </summary>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        int receiverSlots = 0;
        if (!call.Method.IsStatic)
        {
            Debug.Assert(call.Object != null, "Instance method without this");
            receiverSlots = 1;
        }

        LinqExpression[] linqArguments = new LinqExpression[receiverSlots + call.Arguments.Count];
        if (receiverSlots == 1)
        {
            linqArguments[0] = call.Object;
        }
        call.Arguments.CopyTo(linqArguments, receiverSlots);

        return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, linqArguments);
    }
}
#region System.String Method Translators
/// <summary>
/// Translates String.Contains(string).
/// target.Contains(value) -> IndexOf(value, target) > 0
/// </summary>
private sealed class ContainsTranslator : CallTranslator
{
    internal ContainsTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo contains = typeof(string).GetMethod("Contains", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(string) }, null);
        yield return contains;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // The canonical IndexOf is 1-based (IndexOfTranslator subtracts 1 to get the .NET index),
        // so any positive position means the value occurs in the target.
        CqtExpression position = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        CqtExpression zero = parent._commandTree.CreateConstantExpression(0);
        return parent._commandTree.CreateGreaterThanExpression(position, zero);
    }
}
/// <summary>
/// Translates String.IndexOf(string) into the 1-based canonical IndexOf, shifted back to 0-based.
/// target.IndexOf(value) -> IndexOf(value, target) - 1
/// </summary>
private sealed class IndexOfTranslator : CallTranslator
{
    internal IndexOfTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo indexOf = typeof(string).GetMethod("IndexOf", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(string) }, null);
        yield return indexOf;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IndexOf");
        CqtExpression oneBasedPosition = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        CqtExpression one = parent._commandTree.CreateConstantExpression(1);
        return parent._commandTree.CreateMinusExpression(oneBasedPosition, one);
    }
}
/// <summary>
/// Translates String.StartsWith(string).
/// target.StartsWith(value) -> IndexOf(value, target) == 1
/// </summary>
private sealed class StartsWithTranslator : CallTranslator
{
    internal StartsWithTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo startsWith = typeof(string).GetMethod("StartsWith", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(string) }, null);
        yield return startsWith;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // A 1-based match position of 1 means the value sits at the very start of the target.
        CqtExpression position = parent.TranslateIntoCanonicalFunction(ExpressionConverter.IndexOf, call, call.Arguments[0], call.Object);
        CqtExpression firstPosition = parent._commandTree.CreateConstantExpression(1);
        return parent._commandTree.CreateEqualsExpression(position, firstPosition);
    }
}
/// <summary>
/// Translates String.EndsWith(string).
/// target.EndsWith(value) -> Right(target, Length(value)) == value
/// </summary>
private sealed class EndsWithTranslator : CallTranslator
{
    internal EndsWithTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo endsWith = typeof(string).GetMethod("EndsWith", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(string) }, null);
        yield return endsWith;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression suffixLength = parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]);
        CqtExpression target = parent.TranslateExpression(call.Object);
        CqtExpression tail = parent.CreateCanonicalFunction(ExpressionConverter.Right, call, target, suffixLength);
        CqtExpression suffix = parent.TranslateExpression(call.Arguments[0]);
        return parent._commandTree.CreateEqualsExpression(tail, suffix);
    }
}
/// <summary>
/// Translates String.Substring(int) and String.Substring(int, int).
/// Substring(start)         -> Substring(this, start + 1, Length(this) - start)
/// Substring(start, length) -> Substring(this, start + 1, length)
/// </summary>
private sealed class SubstringTranslator : CallTranslator
{
    internal SubstringTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        const BindingFlags publicInstance = BindingFlags.Public | BindingFlags.Instance;
        yield return typeof(string).GetMethod("Substring", publicInstance, null, new[] { typeof(int) }, null);
        yield return typeof(string).GetMethod("Substring", publicInstance, null, new[] { typeof(int), typeof(int) }, null);
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Substring");

        CqtExpression target = parent.TranslateExpression(call.Object);

        // .NET start indexes are 0-based; the canonical Substring start is 1-based.
        CqtExpression start = parent._commandTree.CreatePlusExpression(
            parent.TranslateExpression(call.Arguments[0]),
            parent._commandTree.CreateConstantExpression(1));

        // Without an explicit length, take everything from start to the end of the target.
        CqtExpression length = call.Arguments.Count == 1
            ? parent._commandTree.CreateMinusExpression(
                parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
                parent.TranslateExpression(call.Arguments[0]))
            : parent.TranslateExpression(call.Arguments[1]);

        return parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, target, start, length);
    }
}
/// <summary>
/// Translates String.Remove(int) and String.Remove(int, int).
/// Remove(start)        -> Substring(this, 1, start)
/// Remove(start, count) -> Concat(Substring(this, 1, start),
///                                Substring(this, start + count + 1, Length(this) - (start + count)))
/// The two-argument form is supported only when count translates to a non-negative Int32 constant.
/// </summary>
private sealed class RemoveTranslator : CallTranslator
{
    internal RemoveTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        const BindingFlags publicInstance = BindingFlags.Public | BindingFlags.Instance;
        yield return typeof(string).GetMethod("Remove", publicInstance, null, new[] { typeof(int) }, null);
        yield return typeof(string).GetMethod("Remove", publicInstance, null, new[] { typeof(int), typeof(int) }, null);
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1 || call.Arguments.Count == 2, "Expecting 1 or 2 arguments for String.Remove");
        var tree = parent._commandTree;

        // Prefix kept in front of the removed range: Substring(this, 1, start)
        CqtExpression prefix = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
            parent.TranslateExpression(call.Object),
            tree.CreateConstantExpression(1),
            parent.TranslateExpression(call.Arguments[0]));

        if (call.Arguments.Count != 2)
        {
            return prefix;
        }

        // With two arguments, only a non-negative constant count is supported.
        CqtExpression count = parent.TranslateExpression(call.Arguments[1]);
        if (!IsNonNegativeIntegerConstant(count))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedStringRemoveCase(call.Method, call.Method.GetParameters()[1].Name));
        }

        // Suffix start: (start + count) + 1
        CqtExpression suffixStart = tree.CreatePlusExpression(
            tree.CreatePlusExpression(parent.TranslateExpression(call.Arguments[0]), count),
            tree.CreateConstantExpression(1));

        // Suffix length: Length(this) - (start + count)
        CqtExpression suffixLength = tree.CreateMinusExpression(
            parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
            tree.CreatePlusExpression(
                parent.TranslateExpression(call.Arguments[0]),
                parent.TranslateExpression(call.Arguments[1])));

        CqtExpression suffix = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
            parent.TranslateExpression(call.Object),
            suffixStart,
            suffixLength);

        return parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, prefix, suffix);
    }

    /// <summary>True when the argument is an Int32 constant whose value is zero or greater.</summary>
    private static bool IsNonNegativeIntegerConstant(CqtExpression argument)
    {
        bool isInt32Constant = argument.ExpressionKind == DbExpressionKind.Constant &&
            TypeSemantics.IsPrimitiveType(argument.ResultType, PrimitiveTypeKind.Int32);
        // Short-circuit: the cast below only runs for Int32 constants.
        return isInt32Constant && (int)((DbConstantExpression)argument).Value >= 0;
    }
}
/// <summary>
/// Translates String.Insert(int, string).
/// Insert(start, value) -> Concat(Concat(Substring(this, 1, start), value),
///                                Substring(this, start + 1, Length(this) - start))
/// </summary>
private sealed class InsertTranslator : CallTranslator
{
    internal InsertTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo insert = typeof(string).GetMethod("Insert", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(int), typeof(string) }, null);
        yield return insert;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 2, "Expecting 2 arguments for String.Insert");
        var tree = parent._commandTree;

        // Head: Substring(this, 1, start)
        CqtExpression head = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call,
            parent.TranslateExpression(call.Object),
            tree.CreateConstantExpression(1),
            parent.TranslateExpression(call.Arguments[0]));

        // Tail: Substring(this, start + 1, Length(this) - start)
        CqtExpression tailTarget = parent.TranslateExpression(call.Object);
        CqtExpression tailStart = tree.CreatePlusExpression(
            parent.TranslateExpression(call.Arguments[0]),
            tree.CreateConstantExpression(1));
        CqtExpression tailLength = tree.CreateMinusExpression(
            parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Object),
            parent.TranslateExpression(call.Arguments[0]));
        CqtExpression tail = parent.CreateCanonicalFunction(ExpressionConverter.Substring, call, tailTarget, tailStart, tailLength);

        // Concat(Concat(head, value), tail)
        CqtExpression inserted = parent.TranslateExpression(call.Arguments[1]);
        CqtExpression headAndValue = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, head, inserted);
        return parent.CreateCanonicalFunction(ExpressionConverter.Concat, call, headAndValue, tail);
    }
}
/// <summary>
/// Translates String.IsNullOrEmpty(string).
/// IsNullOrEmpty(value) -> IsNull(value) OR Length(value) = 0
/// </summary>
private sealed class IsNullOrEmptyTranslator : CallTranslator
{
    internal IsNullOrEmptyTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo isNullOrEmpty = typeof(string).GetMethod("IsNullOrEmpty", BindingFlags.Public | BindingFlags.Static, null, new[] { typeof(string) }, null);
        yield return isNullOrEmpty;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.IsNullOrEmpty");
        var tree = parent._commandTree;

        CqtExpression valueIsNull = tree.CreateIsNullExpression(parent.TranslateExpression(call.Arguments[0]));

        CqtExpression valueLength = parent.TranslateIntoCanonicalFunction(ExpressionConverter.Length, call, call.Arguments[0]);
        CqtExpression valueIsEmpty = tree.CreateEqualsExpression(valueLength, tree.CreateConstantExpression(0));

        return tree.CreateOrExpression(valueIsNull, valueIsEmpty);
    }
}
/// <summary>
/// Translates the 2-, 3- and 4-string overloads of String.Concat into a left-nested chain of
/// binary canonical Concat calls, e.g. Concat(a, b, c) -> Concat(Concat(a, b), c).
/// </summary>
private sealed class StringConcatTranslator : CallTranslator
{
    internal StringConcatTranslator() : base(GetMethods()) { }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        const BindingFlags publicStatic = BindingFlags.Public | BindingFlags.Static;
        Type s = typeof(string);
        yield return s.GetMethod("Concat", publicStatic, null, new[] { s, s }, null);
        yield return s.GetMethod("Concat", publicStatic, null, new[] { s, s, s }, null);
        yield return s.GetMethod("Concat", publicStatic, null, new[] { s, s, s, s }, null);
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count >= 2 && call.Arguments.Count <= 4, "Expecting between 2 and 4 arguments for String.Concat");

        CqtExpression accumulated = parent.TranslateExpression(call.Arguments[0]);
        foreach (LinqExpression next in call.Arguments.Skip(1))
        {
            accumulated = parent.CreateCanonicalFunction(ExpressionConverter.Concat, call,
                accumulated,
                parent.TranslateExpression(next));
        }
        return accumulated;
    }
}
/// <summary>
/// Shared translation for String.TrimStart/TrimEnd: target.Method(trimChars) becomes
/// CanonicalFunctionName(target). This is only supported when no explicit trim characters are given.
/// </summary>
private abstract class TrimStartTrimEndBaseTranslator : CallTranslator
{
    // Set once by the constructor (derived classes pass LTrim or RTrim).
    private readonly string _canonicalFunctionName;

    protected TrimStartTrimEndBaseTranslator(IEnumerable<MethodInfo> methods, string canonicalFunctionName)
        : base(methods)
    {
        _canonicalFunctionName = canonicalFunctionName;
    }

    /// <summary>
    /// object.MethodName(trimChars) -> CanonicalFunctionName(object)
    /// </summary>
    /// <exception cref="NotSupportedException">Explicit trim characters were supplied.</exception>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 1, "Expecting 1 argument for String.TrimStart/TrimEnd");
        LinqExpression trimChars = call.Arguments[0];
        if (!IsEmptyArray(trimChars) && !IsNullOrEmptyArrayConstant(trimChars))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedTrimStartTrimEndCase(call.Method));
        }
        return parent.TranslateIntoCanonicalFunction(_canonicalFunctionName, call, call.Object);
    }

    /// <summary>
    /// True when the expression is an array initializer with no elements, which is what a
    /// params call with no arguments (e.g. <c>s.TrimStart()</c>) produces.
    /// </summary>
    internal static bool IsEmptyArray(LinqExpression expression)
    {
        if (expression.NodeType != ExpressionType.NewArrayInit)
        {
            return false;
        }
        NewArrayExpression newArray = (NewArrayExpression)expression;
        return newArray.Expressions.Count == 0;
    }

    /// <summary>
    /// True when the expression is a constant null or a constant zero-length array.
    /// String.TrimStart/TrimEnd treat a null or empty trimChars exactly like the no-argument call,
    /// so these forms translate the same way as <see cref="IsEmptyArray"/>.
    /// </summary>
    private static bool IsNullOrEmptyArrayConstant(LinqExpression expression)
    {
        if (expression.NodeType != ExpressionType.Constant)
        {
            return false;
        }
        object value = ((ConstantExpression)expression).Value;
        if (value == null)
        {
            return true;
        }
        Array array = value as Array;
        return array != null && array.Length == 0;
    }
}
/// <summary>Maps String.TrimStart(params char[]) onto the canonical LTrim function.</summary>
private sealed class TrimStartTranslator : TrimStartTrimEndBaseTranslator
{
    internal TrimStartTranslator()
        : base(GetMethods(), ExpressionConverter.LTrim)
    {
    }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo trimStart = typeof(string).GetMethod("TrimStart", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(char[]) }, null);
        yield return trimStart;
    }
}
/// <summary>Maps String.TrimEnd(params char[]) onto the canonical RTrim function.</summary>
private sealed class TrimEndTranslator : TrimStartTrimEndBaseTranslator
{
    internal TrimEndTranslator()
        : base(GetMethods(), ExpressionConverter.RTrim)
    {
    }

    private static IEnumerable<MethodInfo> GetMethods()
    {
        MethodInfo trimEnd = typeof(string).GetMethod("TrimEnd", BindingFlags.Public | BindingFlags.Instance, null, new[] { typeof(char[]) }, null);
        yield return trimEnd;
    }
}
#endregion
#region Visual Basic Specific Translators
/// <summary>
/// Translates static Visual Basic runtime methods whose canonical function has the same name
/// and arguments: MethodName(a1, .., an) -> MethodName(a1, .., an).
/// </summary>
private sealed class VBCanonicalFunctionDefaultTranslator : CallTranslator
{
    private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";

    internal VBCanonicalFunctionDefaultTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly))
    {
    }

    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        // Microsoft.VisualBasic.Strings
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        foreach (string trimName in new string[] { "Trim", "LTrim", "RTrim" })
        {
            yield return GetPublicStatic(stringsType, trimName, typeof(string));
        }
        foreach (string sliceName in new string[] { "Left", "Right" })
        {
            yield return GetPublicStatic(stringsType, sliceName, typeof(string), typeof(int));
        }

        // Microsoft.VisualBasic.DateAndTime
        Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        foreach (string partName in new string[] { "Year", "Month", "Day", "Hour", "Minute", "Second" })
        {
            yield return GetPublicStatic(dateAndTimeType, partName, typeof(DateTime));
        }
    }

    /// <summary>Looks up a public static overload by exact parameter types.</summary>
    private static MethodInfo GetPublicStatic(Type declaringType, string name, params Type[] parameterTypes)
    {
        return declaringType.GetMethod(name, BindingFlags.Public | BindingFlags.Static, null, parameterTypes, null);
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression[] arguments = new LinqExpression[call.Arguments.Count];
        call.Arguments.CopyTo(arguments, 0);
        return parent.TranslateIntoCanonicalFunction(call.Method.Name, call, arguments);
    }
}
/// <summary>
/// Translates static Visual Basic runtime methods into canonical functions that take the same
/// arguments but have a different name:
/// MethodName(a1, .., an) -> CanonicalFunctionName(a1, .., an)
/// </summary>
private sealed class VBCanonicalFunctionRenameTranslator : CallTranslator
{
    private const string s_stringsTypeFullName = "Microsoft.VisualBasic.Strings";

    // VB method -> canonical function name. Each instance owns its map, which is built eagerly
    // in the constructor and never modified afterwards. The previous version filled a shared
    // static map from inside a lazy iterator, so a second enumeration or a second instance
    // re-ran Add and threw on the duplicate key.
    private readonly Dictionary<MethodInfo, string> _methodNameMap;

    internal VBCanonicalFunctionRenameTranslator(Assembly vbAssembly)
        : this(CreateMethodNameMap(vbAssembly))
    {
    }

    private VBCanonicalFunctionRenameTranslator(Dictionary<MethodInfo, string> methodNameMap)
        : base(methodNameMap.Keys)
    {
        _methodNameMap = methodNameMap;
    }

    /// <summary>Builds the VB method -> canonical function name map.</summary>
    private static Dictionary<MethodInfo, string> CreateMethodNameMap(Assembly vbAssembly)
    {
        //Strings Types
        Type stringsType = vbAssembly.GetType(s_stringsTypeFullName);
        Dictionary<MethodInfo, string> map = new Dictionary<MethodInfo, string>(4);
        AddMethod(map, stringsType, "Len", ExpressionConverter.Length, new Type[] { typeof(string) });
        AddMethod(map, stringsType, "Mid", ExpressionConverter.Substring, new Type[] { typeof(string), typeof(int), typeof(int) });
        AddMethod(map, stringsType, "UCase", ExpressionConverter.ToUpper, new Type[] { typeof(string) });
        AddMethod(map, stringsType, "LCase", ExpressionConverter.ToLower, new Type[] { typeof(string) });
        return map;
    }

    /// <summary>Resolves a public static overload and records its canonical function name.</summary>
    private static void AddMethod(Dictionary<MethodInfo, string> map, Type declaringType, string methodName, string canonicalFunctionName, Type[] argumentTypes)
    {
        MethodInfo methodInfo = declaringType.GetMethod(methodName, BindingFlags.Public | BindingFlags.Static, null, argumentTypes, null);
        Debug.Assert(methodInfo != null, "VB method not found: " + methodName);
        map.Add(methodInfo, canonicalFunctionName);
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        return parent.TranslateIntoCanonicalFunction(_methodNameMap[call.Method], call, call.Arguments.ToArray());
    }
}
/// <summary>
/// Translates Microsoft.VisualBasic.DateAndTime.DatePart(interval, date, firstDayOfWeek, firstWeekOfYear)
/// into the canonical function named by the interval: DatePart(DateInterval.X, date, ..) -> X(date).
/// Only constant intervals listed in s_supportedIntervals are supported.
/// </summary>
private sealed class VBDatePartTranslator : CallTranslator
{
    private const string s_dateAndTimeTypeFullName = "Microsoft.VisualBasic.DateAndTime";
    private const string s_DateIntervalFullName = "Microsoft.VisualBasic.DateInterval";
    private const string s_FirstDayOfWeekFullName = "Microsoft.VisualBasic.FirstDayOfWeek";
    private const string s_FirstWeekOfYearFullName = "Microsoft.VisualBasic.FirstWeekOfYear";

    // Interval names accepted by Translate. Each is compared against the constant interval's
    // ToString() and then used as the canonical function name. readonly: the set is built once
    // and only read afterwards.
    private static readonly HashSet<string> s_supportedIntervals = new HashSet<string>
    {
        ExpressionConverter.Year,
        ExpressionConverter.Month,
        ExpressionConverter.Day,
        ExpressionConverter.Hour,
        ExpressionConverter.Minute,
        ExpressionConverter.Second,
    };

    internal VBDatePartTranslator(Assembly vbAssembly)
        : base(GetMethods(vbAssembly)) { }

    private static IEnumerable<MethodInfo> GetMethods(Assembly vbAssembly)
    {
        Type dateAndTimeType = vbAssembly.GetType(s_dateAndTimeTypeFullName);
        Type dateIntervalEnum = vbAssembly.GetType(s_DateIntervalFullName);
        Type firstDayOfWeekEnum = vbAssembly.GetType(s_FirstDayOfWeekFullName);
        Type firstWeekOfYearEnum = vbAssembly.GetType(s_FirstWeekOfYearFullName);
        yield return dateAndTimeType.GetMethod("DatePart", BindingFlags.Public | BindingFlags.Static, null,
            new Type[] { dateIntervalEnum, typeof(DateTime), firstDayOfWeekEnum, firstWeekOfYearEnum }, null);
    }

    /// <exception cref="NotSupportedException">
    /// The interval is not a constant, or it is not one of the supported intervals.
    /// </exception>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 4, "Expecting 4 arguments for Microsoft.VisualBasic.DateAndTime.DatePart");
        ConstantExpression intervalLinqExpression = call.Arguments[0] as ConstantExpression;
        if (intervalLinqExpression == null)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartNonConstantInterval(call.Method, call.Method.GetParameters()[0].Name));
        }
        string intervalValue = intervalLinqExpression.Value.ToString();
        if (!s_supportedIntervals.Contains(intervalValue))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedVBDatePartInvalidInterval(call.Method, call.Method.GetParameters()[0].Name, intervalValue));
        }
        // Only the date argument is passed on; the first-day/first-week arguments are ignored.
        return parent.TranslateIntoCanonicalFunction(intervalValue, call, call.Arguments[1]);
    }
}
#endregion
#endregion
#region Sequence method translators
/// <summary>
/// Base class for translators of the LINQ sequence operators identified by <see cref="SequenceMethod"/>.
/// </summary>
private abstract class SequenceMethodTranslator
{
    private readonly IEnumerable<SequenceMethod> _methods;

    protected SequenceMethodTranslator(params SequenceMethod[] methods)
    {
        _methods = methods;
    }

    /// <summary>The sequence methods this translator handles.</summary>
    internal IEnumerable<SequenceMethod> Methods
    {
        get { return _methods; }
    }

    /// <summary>
    /// Entry point used when the exact sequence method is known. By default the method identity
    /// is ignored and the call is forwarded to the two-argument overload.
    /// </summary>
    internal virtual CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
    {
        return Translate(parent, call);
    }

    internal abstract CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call);

    public override string ToString()
    {
        Type translatorType = GetType();
        return translatorType.Name;
    }
}
/// <summary>
/// Shared translation for Skip/Take. A projection at the root of the operand is lifted out,
/// paging is applied to the projection's input, and the projection is then re-applied on top.
/// </summary>
private abstract class PagingTranslator : UnarySequenceMethodTranslator
{
    protected PagingTranslator(params SequenceMethod[] methods) : base(methods) { }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        Debug.Assert(call.Arguments.Count == 2, "Skip and Take must have 2 arguments");
        CqtExpression count = parent.TranslateExpression(call.Arguments[1]);

        DbProjectExpression rootProjection = operand.ExpressionKind == DbExpressionKind.Project
            ? (DbProjectExpression)operand
            : null;
        CqtExpression pagingInput = rootProjection == null ? operand : rootProjection.Input.Expression;

        CqtExpression paged = TranslatePagingOperator(parent, pagingInput, count);
        if (rootProjection == null)
        {
            return paged;
        }

        // Re-point the lifted projection at the paged input (the projection node is mutated in place).
        rootProjection.Input.Expression = paged;
        return rootProjection;
    }

    /// <summary>Applies the concrete paging operator (limit or skip) to the operand.</summary>
    protected abstract CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count);
}
/// <summary>Translates Take(count) into a limit over the operand.</summary>
private sealed class TakeTranslator : PagingTranslator
{
    internal TakeTranslator()
        : base(SequenceMethod.Take)
    {
    }

    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        CqtExpression limited = parent.Limit(operand, count);
        return limited;
    }
}
/// <summary>
/// Translates Skip(count). The operand must be sorted; the sort is folded into the skip
/// expression, and any span attached to the sort carries over to the skip.
/// </summary>
private sealed class SkipTranslator : PagingTranslator
{
    internal SkipTranslator()
        : base(SequenceMethod.Skip)
    {
    }

    /// <exception cref="NotSupportedException">The operand is not a sort expression.</exception>
    protected override CqtExpression TranslatePagingOperator(ExpressionConverter parent, CqtExpression operand, CqtExpression count)
    {
        if (operand.ExpressionKind != DbExpressionKind.Sort)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_SkipWithoutOrder);
        }

        DbSortExpression sortedInput = (DbSortExpression)operand;
        Span inheritedSpan;
        bool spanFound = parent.TryGetSpan(sortedInput, out inheritedSpan);

        DbSkipExpression skipExpression = parent.Skip(sortedInput.Input, sortedInput.SortOrder, count);
        if (spanFound)
        {
            parent.AddSpanMapping(skipExpression, inheritedSpan);
        }
        return skipExpression;
    }
}
/// <summary>
/// Rejects every Single/SingleOrDefault overload (with or without a predicate).
/// </summary>
private sealed class SingleTranslatorNotSupported : SequenceMethodTranslator
{
    internal SingleTranslatorNotSupported()
        : base(SequenceMethod.Single,
               SequenceMethod.SinglePredicate,
               SequenceMethod.SingleOrDefault,
               SequenceMethod.SingleOrDefaultPredicate)
    {
    }

    /// <exception cref="NotSupportedException">Always.</exception>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Exception unsupported = EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedSingle);
        throw unsupported;
    }
}
// Translates Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector) into an
// inner join on key equality, followed by a projection computed by the result selector.
private sealed class JoinTranslator : SequenceMethodTranslator
{
internal JoinTranslator() : base(SequenceMethod.Join) { }
// Arguments: [0] outer source, [1] inner source, [2] outer key lambda, [3] inner key lambda,
// [4] result selector lambda taking (outer, inner).
// Throws NotSupportedException when either key selector's result type is not equality-comparable.
internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
{
Debug.Assert(5 == call.Arguments.Count);
// get expressions describing inputs to the join
CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
CqtExpression inner = parent.TranslateSet(call.Arguments[1]);
// get expressions describing key selectors
LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);
// translate key selectors; each TranslateLambda also returns the binding over its input,
// and those bindings become the two sides of the join
DbExpressionBinding outerBinding;
DbExpressionBinding innerBinding;
CqtExpression outerKeySelector = parent.TranslateLambda(outerLambda, outer, out outerBinding);
CqtExpression innerKeySelector = parent.TranslateLambda(innerLambda, inner, out innerBinding);
// construct join expression
if (!TypeSemantics.IsEqualComparable(outerKeySelector.ResultType) ||
!TypeSemantics.IsEqualComparable(innerKeySelector.ResultType))
{
throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
}
// join condition: key equality built with EqualsPattern.PositiveNullEquality
// (see EqualsPattern for the exact null semantics)
DbJoinExpression join = parent._commandTree.CreateInnerJoinExpression(
outerBinding, innerBinding,
parent.CreateEqualsExpression(outerKeySelector, innerKeySelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type));
DbExpressionBinding joinBinding = parent._commandTree.CreateExpressionBinding(join);
// get selector expression
LambdaExpression selectorLambda = parent.GetLambdaExpression(call, 4);
// create property expressions for the inner and outer; each join row exposes the two sides
// as properties named after the original binding variables
DbPropertyExpression joinOuter = parent._commandTree.CreatePropertyExpression(
outerBinding.VariableName, joinBinding.Variable);
DbPropertyExpression joinInner = parent._commandTree.CreatePropertyExpression(
innerBinding.VariableName, joinBinding.Variable);
// push outer and inner join parts into the binding scope (the order
// is irrelevant because the binding context matches based on parameter
// reference rather than ordinal)
parent._bindingContext.PushBindingScope(
new Binding(selectorLambda.Parameters[0], joinOuter),
new Binding(selectorLambda.Parameters[1], joinInner));
// translate join selector
// NOTE(review): the pop below is not in a finally block, so if this translation throws the
// scope stays pushed. That is harmless only if the converter is discarded after a failure.
CqtExpression selector = parent.TranslateExpression(selectorLambda.Body);
// pop binding scope
parent._bindingContext.PopBindingScope();
return parent._commandTree.CreateProjectExpression(joinBinding, selector);
}
}
/// <summary>
/// Base for set operators with two sequence operands. It accepts either the instance form
/// <c>left.Op(right)</c> or the static extension form <c>Op(left, right)</c>.
/// </summary>
private abstract class BinarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected BinarySequenceMethodTranslator(params SequenceMethod[] methods) : base(methods) { }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression leftSource;
        LinqExpression rightSource;
        if (call.Object == null)
        {
            // static extension method: Op(left, right)
            Debug.Assert(2 == call.Arguments.Count);
            leftSource = call.Arguments[0];
            rightSource = call.Arguments[1];
        }
        else
        {
            // instance method: left.Op(right)
            Debug.Assert(1 == call.Arguments.Count);
            leftSource = call.Object;
            rightSource = call.Arguments[0];
        }

        CqtExpression left = parent.TranslateSet(leftSource);
        CqtExpression right = parent.TranslateSet(rightSource);
        return TranslateBinary(parent, left, right);
    }

    /// <summary>Combines the two translated operand sets.</summary>
    protected abstract CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right);
}
/// <summary>
/// Translates Concat as a plain UnionAll of the two inputs (no Distinct is applied, unlike Union).
/// </summary>
private class ConcatTranslator : BinarySequenceMethodTranslator
{
    internal ConcatTranslator()
        : base(SequenceMethod.Concat)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression concatenated = parent.UnionAll(left, right);
        return concatenated;
    }
}
/// <summary>
/// Translates Union as Distinct applied over the UnionAll of the two inputs.
/// </summary>
private sealed class UnionTranslator : BinarySequenceMethodTranslator
{
    internal UnionTranslator()
        : base(SequenceMethod.Union)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        // concatenate both inputs first, then remove duplicates
        CqtExpression concatenated = parent.UnionAll(left, right);
        return parent.Distinct(concatenated);
    }
}
/// <summary>
/// Translates Intersect by delegating to ExpressionConverter.Intersect.
/// </summary>
private sealed class IntersectTranslator : BinarySequenceMethodTranslator
{
    internal IntersectTranslator()
        : base(SequenceMethod.Intersect)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression intersection = parent.Intersect(left, right);
        return intersection;
    }
}
/// <summary>
/// Translates Except by delegating to ExpressionConverter.Except (left input minus right input).
/// </summary>
private sealed class ExceptTranslator : BinarySequenceMethodTranslator
{
    internal ExceptTranslator()
        : base(SequenceMethod.Except)
    {
    }

    protected override CqtExpression TranslateBinary(ExpressionConverter parent, CqtExpression left, CqtExpression right)
    {
        CqtExpression difference = parent.Except(left, right);
        return difference;
    }
}
/// <summary>
/// Base class for aggregate sequence methods (Max, Min, Average, Sum, Count, LongCount in this file).
/// The (optionally filtered or projected) input set is wrapped in a group-by with no keys and a
/// single function aggregate. The single resulting value is then extracted with an element expression.
/// </summary>
private abstract class AggregateTranslator : SequenceMethodTranslator
{
    // Canonical function name handed to FindCanonicalFunction (e.g. "MAX", "COUNT").
    private readonly string _functionName;
    // True when the optional lambda argument is a predicate (filter); false when it is a selector.
    private readonly bool _takesPredicate;

    protected AggregateTranslator(string functionName, bool takesPredicate, params SequenceMethod[] methods)
        : base(methods)
    {
        _takesPredicate = takesPredicate;
        _functionName = functionName;
    }

    /// <summary>
    /// Translates an aggregate call that has either one argument (the source) or two
    /// (the source plus a predicate or selector lambda).
    /// </summary>
    /// <returns>An element expression over a single-row projection of the aggregate result.</returns>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool isUnary = 1 == call.Arguments.Count;
        Debug.Assert(isUnary || 2 == call.Arguments.Count);

        CqtExpression operand = parent.TranslateSet(call.Arguments[0]);

        // Facet information for the return type cannot help in determining the appropriate function overload
        // since no constraints on the return value are known.
        TypeUsage returnType = parent.GetValueLayerType(call.Type);

        LambdaExpression lambda = null;
        // Keep the untransformed operand: the optimized (group-by) translation is derived from it.
        CqtExpression originalOperand = operand;
        if (!isUnary)
        {
            lambda = parent.GetLambdaExpression(call, 1);
            DbExpressionBinding sourceBinding;
            CqtExpression cqtLambda = parent.TranslateLambda(lambda, operand, out sourceBinding);
            if (_takesPredicate)
            {
                // treat the lambda as a filter
                operand = parent.Filter(sourceBinding, cqtLambda);
            }
            else
            {
                // treat the lambda as a selector
                operand = parent._commandTree.CreateProjectExpression(sourceBinding, cqtLambda);
            }
        }

        operand = WrapCollectionOperand(parent, operand, returnType);
        DbGroupExpressionBinding operandBinding = parent._commandTree.CreateGroupExpressionBinding(operand);
        EdmFunction function = FindFunction(parent, call, returnType);

        // create aggregate: a group-by with no keys and exactly one aggregate
        List<KeyValuePair<string, CqtExpression>> keys = new List<KeyValuePair<string, CqtExpression>>(0); // no key
        List<KeyValuePair<string, DbAggregate>> aggregates = new List<KeyValuePair<string, DbAggregate>>(1);
        const string aggregateName = "Aggregate"; // name is arbitrary (there is only one in this context)
        aggregates.Add(new KeyValuePair<string, DbAggregate>(aggregateName,
            parent._commandTree.CreateFunctionAggregate(function, operandBinding.GroupVariable)));
        DbGroupByExpression aggregate = parent._commandTree.CreateGroupByExpression(
            operandBinding, keys, aggregates);
        DbExpressionBinding aggregateBinding = parent._commandTree.CreateExpressionBinding(aggregate);

        // project result
        DbPropertyExpression property = parent._commandTree.CreatePropertyExpression(
            aggregateName, aggregateBinding.Variable);
        DbProjectExpression projection = parent._commandTree.CreateProjectExpression(
            aggregateBinding, parent.AlignTypes(property, call.Type));

        // return a single element to represent the projection
        DbElementExpression element = parent._commandTree.CreateElementExpression(projection);

        // Try to create and log an optimized translation
        TryCreateOptimizedTranslation(parent, lambda, originalOperand, function, element);
        return element;
    }

    /// <summary>
    /// Tries to fold this aggregate into an optimized GroupBy translation. This happens only when
    /// <paramref name="operand"/> is a property of a variable whose input expression has a registered
    /// DbGroupByTemplate. In that case a new aggregate named "Aggregate{n}" is appended to the template.
    /// The pair (template, alias) is then recorded against <paramref name="originalTranslation"/> in
    /// parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap. Otherwise nothing is changed.
    /// </summary>
    private void TryCreateOptimizedTranslation(ExpressionConverter parent, LambdaExpression lambda, CqtExpression operand, EdmFunction function, CqtExpression originalTranslation)
    {
        //Aggregates that take predicates as arguments cannot be incorporated into a group by
        if (_takesPredicate && (lambda != null))
        {
            return;
        }

        //Check whether the operand is a property over an output of grouping
        if (operand.ExpressionKind != DbExpressionKind.Property)
        {
            return;
        }
        DbPropertyExpression propertyExpression = (DbPropertyExpression)operand;
        if (propertyExpression.Instance.ExpressionKind != DbExpressionKind.VariableReference)
        {
            return;
        }
        DbVariableReferenceExpression inputVarRef = (DbVariableReferenceExpression)propertyExpression.Instance;

        //If the input corresponding to the var ref has an optimized translation, generate an alternate translation for this as well.
        CqtExpression input;
        if (!parent._variableNameToInputExpression.TryGetValue(inputVarRef.VariableName, out input))
        {
            return;
        }
        DbGroupByTemplate optimizedTranslationOfInput;
        if (!parent._groupByDefaultToOptimizedTranslationMap.TryGetValue(input, out optimizedTranslationOfInput))
        {
            return;
        }

        Debug.Assert(TypeSemantics.IsCollectionType(function.Parameters[0].TypeUsage), "Aggregates should always have collection arguments");
        TypeUsage elementType = TypeHelpers.GetElementTypeUsage(function.Parameters[0].TypeUsage);

        // The aggregate can be added to the list of aggregates; a selector lambda (if any)
        // is applied to the template's group variable.
        CqtExpression aggregateArgument = optimizedTranslationOfInput.Input.GroupVariable;
        if (lambda != null)
        {
            aggregateArgument = parent.TranslateLambda(lambda, aggregateArgument);
        }

        string optimizedTranslationAlias = String.Format(CultureInfo.InvariantCulture, "Aggregate{0}", optimizedTranslationOfInput.Aggregates.Count);
        optimizedTranslationOfInput.Aggregates.Add(new KeyValuePair<string, DbAggregate>(optimizedTranslationAlias,
            parent._commandTree.CreateFunctionAggregate(function,
                WrapNonCollectionOperand(parent, aggregateArgument, elementType))));

        //log the alias
        parent._aggregateDefaultTranslationToOptimizedTranslationInfoMap.Add(originalTranslation,
            new KeyValuePair<DbGroupByTemplate, string>(optimizedTranslationOfInput, optimizedTranslationAlias));
    }

    /// <summary>
    /// If the operand's element type differs from <paramref name="returnType"/> (per TypeUsageEquals),
    /// projects every element through a cast to <paramref name="returnType"/>. This ensures the appropriate
    /// aggregate overload is called. Otherwise the operand is returned unchanged.
    /// </summary>
    protected virtual CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand,
        TypeUsage returnType)
    {
        // check if the operand needs to be wrapped to ensure the correct function overload is called
        if (!TypeUsageEquals(returnType, ((CollectionType)operand.ResultType.EdmType).TypeUsage))
        {
            DbExpressionBinding operandCastBinding = parent._commandTree.CreateExpressionBinding(operand);
            DbProjectExpression operandCastProjection = parent._commandTree.CreateProjectExpression(
                operandCastBinding, parent._commandTree.CreateCastExpression(operandCastBinding.Variable, returnType));
            operand = operandCastProjection;
        }
        return operand;
    }

    /// <summary>
    /// Non-collection counterpart of WrapCollectionOperand, used for the optimized group-by translation.
    /// It casts <paramref name="operand"/> to <paramref name="returnType"/> when their types differ.
    /// </summary>
    protected virtual CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand,
        TypeUsage returnType)
    {
        if (!TypeUsageEquals(returnType, operand.ResultType))
        {
            operand = parent._commandTree.CreateCastExpression(operand, returnType);
        }
        return operand;
    }

    /// <summary>
    /// Finds the canonical group-aggregate overload of this translator's function whose single
    /// parameter type is <paramref name="argumentType"/>.
    /// </summary>
    protected virtual EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
        TypeUsage argumentType)
    {
        List<TypeUsage> argTypes = new List<TypeUsage>(1);
        // In general, we use the return type as the parameter type to align LINQ semantics
        // with SQL semantics, and avoid apparent loss of precision for some LINQ aggregate operators.
        // (e.g., AVG(1, 2) = 2.0, AVG((double)1, (double)2)) = 1.5)
        argTypes.Add(argumentType);
        return parent.FindCanonicalFunction(_functionName, argTypes, true /* isGroupAggregateFunction */, call);
    }
}
/// <summary>
/// Maps every Max overload (plain and selector forms, for each numeric/nullable type)
/// to the canonical "MAX" aggregate. The optional lambda is a selector, not a predicate.
/// </summary>
private sealed class MaxTranslator : AggregateTranslator
{
    internal MaxTranslator()
        : base("MAX", false /* takesPredicate */,
            SequenceMethod.Max, SequenceMethod.MaxSelector,
            SequenceMethod.MaxInt, SequenceMethod.MaxIntSelector,
            SequenceMethod.MaxDecimal, SequenceMethod.MaxDecimalSelector,
            SequenceMethod.MaxDouble, SequenceMethod.MaxDoubleSelector,
            SequenceMethod.MaxLong, SequenceMethod.MaxLongSelector,
            SequenceMethod.MaxSingle, SequenceMethod.MaxSingleSelector,
            SequenceMethod.MaxNullableDecimal, SequenceMethod.MaxNullableDecimalSelector,
            SequenceMethod.MaxNullableDouble, SequenceMethod.MaxNullableDoubleSelector,
            SequenceMethod.MaxNullableInt, SequenceMethod.MaxNullableIntSelector,
            SequenceMethod.MaxNullableLong, SequenceMethod.MaxNullableLongSelector,
            SequenceMethod.MaxNullableSingle, SequenceMethod.MaxNullableSingleSelector)
    {
    }
}
/// <summary>
/// Maps every Min overload (plain and selector forms, for each numeric/nullable type)
/// to the canonical "MIN" aggregate. The optional lambda is a selector, not a predicate.
/// </summary>
private sealed class MinTranslator : AggregateTranslator
{
    internal MinTranslator()
        : base("MIN", false /* takesPredicate */,
            SequenceMethod.Min, SequenceMethod.MinSelector,
            SequenceMethod.MinDecimal, SequenceMethod.MinDecimalSelector,
            SequenceMethod.MinDouble, SequenceMethod.MinDoubleSelector,
            SequenceMethod.MinInt, SequenceMethod.MinIntSelector,
            SequenceMethod.MinLong, SequenceMethod.MinLongSelector,
            SequenceMethod.MinSingle, SequenceMethod.MinSingleSelector,
            SequenceMethod.MinNullableDecimal, SequenceMethod.MinNullableDecimalSelector,
            SequenceMethod.MinNullableDouble, SequenceMethod.MinNullableDoubleSelector,
            SequenceMethod.MinNullableInt, SequenceMethod.MinNullableIntSelector,
            SequenceMethod.MinNullableLong, SequenceMethod.MinNullableLongSelector,
            SequenceMethod.MinNullableSingle, SequenceMethod.MinNullableSingleSelector)
    {
    }
}
/// <summary>
/// Maps every typed Average overload (plain and selector forms) to the canonical "AVG" aggregate.
/// The optional lambda is a selector, not a predicate.
/// </summary>
private sealed class AverageTranslator : AggregateTranslator
{
    internal AverageTranslator()
        : base("AVG", false /* takesPredicate */,
            SequenceMethod.AverageDecimal, SequenceMethod.AverageDecimalSelector,
            SequenceMethod.AverageDouble, SequenceMethod.AverageDoubleSelector,
            SequenceMethod.AverageInt, SequenceMethod.AverageIntSelector,
            SequenceMethod.AverageLong, SequenceMethod.AverageLongSelector,
            SequenceMethod.AverageSingle, SequenceMethod.AverageSingleSelector,
            SequenceMethod.AverageNullableDecimal, SequenceMethod.AverageNullableDecimalSelector,
            SequenceMethod.AverageNullableDouble, SequenceMethod.AverageNullableDoubleSelector,
            SequenceMethod.AverageNullableInt, SequenceMethod.AverageNullableIntSelector,
            SequenceMethod.AverageNullableLong, SequenceMethod.AverageNullableLongSelector,
            SequenceMethod.AverageNullableSingle, SequenceMethod.AverageNullableSingleSelector)
    {
    }
}
/// <summary>
/// Maps every typed Sum overload (plain and selector forms) to the canonical "SUM" aggregate.
/// The optional lambda is a selector, not a predicate.
/// </summary>
private sealed class SumTranslator : AggregateTranslator
{
    internal SumTranslator()
        : base("SUM", false /* takesPredicate */,
            SequenceMethod.SumDecimal, SequenceMethod.SumDecimalSelector,
            SequenceMethod.SumDouble, SequenceMethod.SumDoubleSelector,
            SequenceMethod.SumInt, SequenceMethod.SumIntSelector,
            SequenceMethod.SumLong, SequenceMethod.SumLongSelector,
            SequenceMethod.SumSingle, SequenceMethod.SumSingleSelector,
            SequenceMethod.SumNullableDecimal, SequenceMethod.SumNullableDecimalSelector,
            SequenceMethod.SumNullableDouble, SequenceMethod.SumNullableDoubleSelector,
            SequenceMethod.SumNullableInt, SequenceMethod.SumNullableIntSelector,
            SequenceMethod.SumNullableLong, SequenceMethod.SumNullableLongSelector,
            SequenceMethod.SumNullableSingle, SequenceMethod.SumNullableSingleSelector)
    {
    }
}
/// <summary>
/// Shared base for Count and LongCount. The optional lambda is treated as a predicate.
/// Rows are counted by aggregating a constant rather than the elements themselves.
/// </summary>
private abstract class CountTranslatorBase : AggregateTranslator
{
    protected CountTranslatorBase(string functionName, params SequenceMethod[] methods)
        : base(functionName, true /* takesPredicate */, methods)
    {
    }

    /// <summary>
    /// Replaces every element of <paramref name="operand"/> with the constant true, so the
    /// aggregate counts rows irrespective of element type. <paramref name="returnType"/> is not used.
    /// </summary>
    protected override CqtExpression WrapCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        DbExpressionBinding elementBinding = parent._commandTree.CreateExpressionBinding(operand);
        CqtExpression constantTrue = parent._commandTree.CreateTrueExpression();
        return parent._commandTree.CreateProjectExpression(elementBinding, constantTrue);
    }

    /// <summary>
    /// Ignores <paramref name="operand"/> and aggregates the constant 1. The constant is cast to
    /// <paramref name="returnType"/> when its result type differs.
    /// </summary>
    protected override CqtExpression WrapNonCollectionOperand(ExpressionConverter parent, CqtExpression operand, TypeUsage returnType)
    {
        // NOTE(review): the collection path aggregates Boolean true while this path aggregates
        // Int32 1 (cast when needed); confirm the difference is intended.
        DbExpression one = parent._commandTree.CreateConstantExpression(1);
        return TypeUsageEquals(one.ResultType, returnType)
            ? one
            : parent._commandTree.CreateCastExpression(one, returnType);
    }

    /// <summary>
    /// Chooses the overload by the Boolean element type projected in WrapCollectionOperand.
    /// Other aggregates choose by the call's return type; <paramref name="argumentType"/> is ignored here.
    /// </summary>
    protected override EdmFunction FindFunction(ExpressionConverter parent, MethodCallExpression call,
        TypeUsage argumentType)
    {
        TypeUsage projectedElementType = parent._commandTree.CreateTrueExpression().ResultType;
        return base.FindFunction(parent, call, projectedElementType);
    }
}
/// <summary>
/// Translates Count and Count(predicate) to the canonical "COUNT" aggregate.
/// </summary>
private sealed class CountTranslator : CountTranslatorBase
{
    internal CountTranslator()
        : base("COUNT",
            SequenceMethod.Count,
            SequenceMethod.CountPredicate)
    {
    }
}
/// <summary>
/// Translates LongCount and LongCount(predicate) to the canonical "BIGCOUNT" aggregate.
/// </summary>
private sealed class LongCountTranslator : CountTranslatorBase
{
    internal LongCountTranslator()
        : base("BIGCOUNT",
            SequenceMethod.LongCount,
            SequenceMethod.LongCountPredicate)
    {
    }
}
/// <summary>
/// Base class for sequence methods with a single input set. The input is the receiver of an
/// instance call, or the first argument of a static (extension) call. It is translated as a
/// set and passed to <see cref="TranslateUnary"/> together with the original call.
/// </summary>
private abstract class UnarySequenceMethodTranslator : SequenceMethodTranslator
{
    protected UnarySequenceMethodTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression inputSource;
        if (null == call.Object)
        {
            // static extension method: the input is the first argument
            Debug.Assert(1 <= call.Arguments.Count);
            inputSource = call.Arguments[0];
        }
        else
        {
            // instance method: the input is the receiver
            Debug.Assert(0 <= call.Arguments.Count);
            inputSource = call.Object;
        }

        CqtExpression translatedInput = parent.TranslateSet(inputSource);
        return TranslateUnary(parent, translatedInput, call);
    }

    /// <summary>Produces the result for the already-translated input set.</summary>
    protected abstract CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call);
}
/// <summary>
/// Translates AsQueryable (generic and non-generic) and AsEnumerable by returning the
/// translated input unchanged. The input must be collection-typed.
/// </summary>
private sealed class PassthroughTranslator : UnarySequenceMethodTranslator
{
    internal PassthroughTranslator()
        : base(SequenceMethod.AsQueryableGeneric, SequenceMethod.AsQueryable, SequenceMethod.AsEnumerable)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // Reject non-collection operands so that (for instance) a String is not
        // treated as a sub-query.
        if (!TypeSemantics.IsCollectionType(operand.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedPassthrough(
                call.Method.Name, operand.ResultType.EdmType.Name));
        }
        return operand;
    }
}
/// <summary>
/// Translates OfType&lt;T&gt;() into an OfType expression. T must map to an EntityType or a
/// ComplexType in the value layer.
/// </summary>
private sealed class OfTypeTranslator : UnarySequenceMethodTranslator
{
    internal OfTypeTranslator()
        : base(SequenceMethod.OfType)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        Type targetClrType = call.Method.GetGenericArguments()[0];
        TypeUsage targetModelType;

        // OfType() is only meaningful for entity and complex types; scalars, enumerations,
        // collections, or types missing from the perspective are rejected.
        bool isSupportedTarget =
            parent.TryGetValueLayerType(targetClrType, out targetModelType) &&
            (TypeSemantics.IsEntityType(targetModelType) || TypeSemantics.IsComplexType(targetModelType));
        if (!isSupportedTarget)
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_InvalidOfTypeResult(DescribeClrType(targetClrType)));
        }

        // keep only the results of the requested type
        return parent.OfType(operand, targetModelType);
    }
}
/// <summary>
/// Translates Distinct() by delegating to ExpressionConverter.Distinct on the input set.
/// </summary>
private sealed class DistinctTranslator : UnarySequenceMethodTranslator
{
    internal DistinctTranslator()
        : base(SequenceMethod.Distinct)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        CqtExpression distinctSet = parent.Distinct(operand);
        return distinctSet;
    }
}
/// <summary>
/// Translates the parameterless Any() as NOT(IsEmpty(input)), i.e. an "exists" test.
/// </summary>
private sealed class AnyTranslator : UnarySequenceMethodTranslator
{
    internal AnyTranslator()
        : base(SequenceMethod.Any)
    {
    }

    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand,
        MethodCallExpression call)
    {
        CqtExpression isEmpty = parent._commandTree.CreateIsEmptyExpression(operand);
        return parent._commandTree.CreateNotExpression(isEmpty);
    }
}
/// <summary>
/// Base class for sequence methods of the form Method(source, lambda). The source is
/// call.Arguments[0]; the lambda is call.Arguments[1], translated against a binding over the source.
/// </summary>
private abstract class OneLambdaTranslator : SequenceMethodTranslator
{
    internal OneLambdaTranslator(params SequenceMethod[] methods)
        : base(methods)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // The intermediate results are only of interest to derived translators that override this method.
        CqtExpression unusedSource;
        DbExpressionBinding unusedBinding;
        CqtExpression unusedLambda;
        return Translate(parent, call, out unusedSource, out unusedBinding, out unusedLambda);
    }

    /// <summary>
    /// Helper method for translation. It translates the source and the lambda, then delegates to
    /// <see cref="TranslateOneLambda"/>. The intermediate pieces are also returned to the caller.
    /// </summary>
    /// <param name="source">Receives the translated source expression.</param>
    /// <param name="sourceBinding">Receives the binding the lambda was translated against.</param>
    /// <param name="lambda">Receives the translated lambda body.</param>
    protected CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, out CqtExpression source, out DbExpressionBinding sourceBinding, out CqtExpression lambda)
    {
        Debug.Assert(2 <= call.Arguments.Count);

        source = parent.TranslateExpression(call.Arguments[0]);
        LambdaExpression linqLambda = parent.GetLambdaExpression(call, 1);
        lambda = parent.TranslateLambda(linqLambda, source, out sourceBinding);

        return TranslateOneLambda(parent, sourceBinding, lambda);
    }

    /// <summary>Builds the result from the source binding and the translated lambda body.</summary>
    protected abstract CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda);
}
/// <summary>
/// Translates Any(predicate) into an Any expression over the source binding.
/// </summary>
private sealed class AnyPredicateTranslator : OneLambdaTranslator
{
    internal AnyPredicateTranslator()
        : base(SequenceMethod.AnyPredicate)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression anyExpression = parent._commandTree.CreateAnyExpression(sourceBinding, lambda);
        return anyExpression;
    }
}
/// <summary>
/// Translates All(predicate) into an All expression over the source binding.
/// </summary>
private sealed class AllTranslator : OneLambdaTranslator
{
    internal AllTranslator()
        : base(SequenceMethod.All)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression allExpression = parent._commandTree.CreateAllExpression(sourceBinding, lambda);
        return allExpression;
    }
}
/// <summary>
/// Translates Where(predicate) into a filter over the source binding.
/// </summary>
private sealed class WhereTranslator : OneLambdaTranslator
{
    internal WhereTranslator()
        : base(SequenceMethod.Where)
    {
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression filtered = parent.Filter(sourceBinding, lambda);
        return filtered;
    }
}
/// <summary>
/// Translates Select(selector) into a projection. When the source is a GroupBy, an optimized
/// rewrite produced by ExpressionConverter.TryRewrite is preferred when available.
/// </summary>
private sealed class SelectTranslator : OneLambdaTranslator
{
    internal SelectTranslator()
        : base(SequenceMethod.Select)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        CqtExpression source;
        DbExpressionBinding sourceBinding;
        CqtExpression selector;
        CqtExpression defaultTranslation = Translate(parent, call, out source, out sourceBinding, out selector);

        // The default translation does the grouping 'manually'; the rewrite (if any) is
        // expressed as a DbGroupByExpression instead.
        CqtExpression rewritten;
        return parent.TryRewrite(source, sourceBinding, selector, out rewritten)
            ? rewritten
            : defaultTranslation;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression projection = parent.Project(sourceBinding, lambda);
        return projection;
    }
}
/// <summary>
/// Shared implementation of First() and FirstOrDefault() without a predicate.
/// </summary>
private abstract class FirstTranslatorBase : UnarySequenceMethodTranslator
{
    // True for FirstOrDefault: the operation may then appear below the query root.
    private readonly bool _orDefault;

    protected FirstTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    /// <summary>
    /// Restricts <paramref name="operand"/> to at most one element. Below the query root,
    /// FirstOrDefault additionally extracts that element (with a default-value case).
    /// First there is rejected.
    /// </summary>
    /// <exception>NotSupported (ELinq_UnsupportedNestedFirst) for a First() that is not the query root.</exception>
    protected override CqtExpression TranslateUnary(ExpressionConverter parent, CqtExpression operand, MethodCallExpression call)
    {
        // Apply a Limit(1) expression to restrict the result to a single element
        // (this returns an empty set if the input set is initially empty).
        CqtExpression result = parent.Limit(operand, parent._commandTree.CreateConstantExpression(1));

        // If this First() or FirstOrDefault() operation is the root of the query,
        // then the evaluation is performed in the client over the resulting set,
        // to provide the same semantics as Linq to Objects. Otherwise, an Element
        // expression is applied to retrieve the single element (or null, if empty)
        // from the output set.
        if (!parent.IsQueryRoot(call))
        {
            if (_orDefault)
            {
                result = parent._commandTree.CreateElementExpression(result);
                result = AddDefaultCase(parent, result, call.Type);
            }
            else
            {
                // First is not allowed anywhere other than the root of the query,
                // since First() operations logically executed in the store will not
                // throw an exception if the input set is empty (typically they will
                // simply produce a null result).
                throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
            }
        }

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span inputSpan = null;
        if (parent.TryGetSpan(operand, out inputSpan))
        {
            parent.AddSpanMapping(result, inputSpan);
        }
        return result;
    }

    /// <summary>
    /// Wraps <paramref name="element"/> as CASE WHEN element IS NULL THEN default ELSE element.
    /// The default is the CLR default value of <paramref name="elementType"/>. When that default
    /// is itself null, <paramref name="element"/> is returned unchanged.
    /// </summary>
    internal static CqtExpression AddDefaultCase(ExpressionConverter parent, CqtExpression element, Type elementType)
    {
        // Retrieve default value.
        object defaultValue = TypeSystem.GetDefaultValue(elementType);
        if (null == defaultValue)
        {
            // Already null, which is the implicit default for DbElementExpression
            return element;
        }

        // Otherwise, use the default value for the type
        List<CqtExpression> whenExpressions = new List<CqtExpression>(1);
        whenExpressions.Add(parent.CreateIsNullExpression(element, elementType));
        List<CqtExpression> thenExpressions = new List<CqtExpression>(1);
        thenExpressions.Add(parent._commandTree.CreateConstantExpression(defaultValue));
        DbCaseExpression caseExpression = parent._commandTree.CreateCaseExpression(
            whenExpressions, thenExpressions, element);
        return caseExpression;
    }
}
/// <summary>
/// Translates First() without a predicate. It is only supported at the query root.
/// </summary>
private sealed class FirstTranslator : FirstTranslatorBase
{
    internal FirstTranslator()
        : base(false /* orDefault */, SequenceMethod.First)
    {
    }
}
/// <summary>
/// Translates FirstOrDefault() without a predicate.
/// </summary>
private sealed class FirstOrDefaultTranslator : FirstTranslatorBase
{
    internal FirstOrDefaultTranslator()
        : base(true /* orDefault */, SequenceMethod.FirstOrDefault)
    {
    }
}
/// <summary>
/// Shared implementation of First(predicate) and FirstOrDefault(predicate). The predicate
/// becomes a filter over the source; the filtered set is then handled like First/FirstOrDefault.
/// </summary>
private abstract class FirstPredicateTranslatorBase : OneLambdaTranslator
{
    // True for FirstOrDefault(predicate): the operation may then appear below the query root.
    private readonly bool _orDefault;

    protected FirstPredicateTranslatorBase(bool orDefault, params SequenceMethod[] methods)
        : base(methods)
    {
        _orDefault = orDefault;
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // source + predicate -> filter expression (see TranslateOneLambda)
        CqtExpression filtered = base.Translate(parent, call);

        if (parent.IsQueryRoot(call))
        {
            // At the root, the result is produced by calling First()/FirstOrDefault() on the
            // filtered set, limited to at most one element. ExpressionConverter.Limit propagates the Span.
            return parent.Limit(filtered, parent._commandTree.CreateConstantExpression(1));
        }

        if (!_orDefault)
        {
            // A nested First(predicate) is rejected: executed in the store it would not throw
            // for an empty input (typically it would simply produce a null result).
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedNestedFirst);
        }

        CqtExpression singleElement = parent._commandTree.CreateElementExpression(filtered);
        singleElement = FirstTranslatorBase.AddDefaultCase(parent, singleElement, call.Type);

        // Span is preserved over First/FirstOrDefault with or without a predicate
        Span filteredSpan;
        if (parent.TryGetSpan(filtered, out filteredSpan))
        {
            parent.AddSpanMapping(singleElement, filteredSpan);
        }
        return singleElement;
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        CqtExpression filtered = parent.Filter(sourceBinding, lambda);
        return filtered;
    }
}
/// <summary>
/// Translates First(predicate). It is only supported at the query root.
/// </summary>
private sealed class FirstPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstPredicateTranslator()
        : base(false /* orDefault */, SequenceMethod.FirstPredicate)
    {
    }
}
/// <summary>
/// Translates FirstOrDefault(predicate).
/// </summary>
private sealed class FirstOrDefaultPredicateTranslator : FirstPredicateTranslatorBase
{
    internal FirstOrDefaultPredicateTranslator()
        : base(true /* orDefault */, SequenceMethod.FirstOrDefaultPredicate)
    {
    }
}
/// <summary>
/// Translates SelectMany, with or without a result selector.
/// </summary>
private sealed class SelectManyTranslator : OneLambdaTranslator
{
    internal SelectManyTranslator()
        : base(SequenceMethod.SelectMany, SequenceMethod.SelectManyResultSelector)
    {
    }

    /// <summary>
    /// SelectMany(i, collectionSelector) => i CROSS APPLY collectionSelector(i), projecting the
    /// right-hand column. With a result selector (a third argument), the projection is instead
    /// resultSelector(left, right), evaluated over the cross-apply output.
    /// </summary>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        bool hasResultSelector = 3 == call.Arguments.Count;

        // core logic: cross apply the collection selector (see TranslateOneLambda)
        CqtExpression crossApply = base.Translate(parent, call);

        DbExpressionBinding applyBinding = parent._commandTree.CreateExpressionBinding(crossApply);
        RowType applyRowType = (RowType)(applyBinding.Variable.ResultType.EdmType);
        CqtExpression rightColumn = parent._commandTree.CreatePropertyExpression(applyRowType.Properties[1], applyBinding.Variable);

        // without a result selector, the right-hand side of the apply is the result
        CqtExpression projection = rightColumn;
        if (hasResultSelector)
        {
            CqtExpression leftColumn = parent._commandTree.CreatePropertyExpression(applyRowType.Properties[0], applyBinding.Variable);
            LambdaExpression resultSelector = parent.GetLambdaExpression(call, 2);

            // bind the selector's two parameters to the left and right columns while translating its body
            parent._bindingContext.PushBindingScope(
                new Binding(resultSelector.Parameters[0], leftColumn),
                new Binding(resultSelector.Parameters[1], rightColumn));
            projection = parent.TranslateSet(resultSelector.Body);
            parent._bindingContext.PopBindingScope();
        }

        return parent._commandTree.CreateProjectExpression(applyBinding, projection);
    }

    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        // apply over the elements produced by the collection selector
        CqtExpression normalizedSelector = parent.NormalizeSetSource(lambda);
        DbExpressionBinding selectorBinding = parent._commandTree.CreateExpressionBinding(normalizedSelector);
        return parent._commandTree.CreateCrossApplyExpression(sourceBinding, selectorBinding);
    }
}
/// <summary>
/// Translates Cast&lt;T&gt;() as a projection that casts each element of the source. The cast goes
/// from the source's element type to the result's element type.
/// </summary>
private sealed class CastMethodTranslator : SequenceMethodTranslator
{
    internal CastMethodTranslator()
        : base(SequenceMethod.Cast)
    {
    }

    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        LinqExpression sourceLinq = call.Arguments[0];
        CqtExpression source = parent.TranslateSet(sourceLinq);

        Type targetElementType = TypeSystem.GetElementType(call.Type);
        Type sourceElementType = TypeSystem.GetElementType(sourceLinq.Type);

        DbExpressionBinding elementBinding = parent._commandTree.CreateExpressionBinding(source);
        CqtExpression castElement = parent.CreateCastExpression(elementBinding.Variable, targetElementType, sourceElementType);
        return parent._commandTree.CreateProjectExpression(elementBinding, castElement);
    }
}
private sealed class GroupByTranslator : SequenceMethodTranslator
{
internal GroupByTranslator()
: base(SequenceMethod.GroupBy, SequenceMethod.GroupByElementSelector, SequenceMethod.GroupByElementSelectorResultSelector,
SequenceMethod.GroupByResultSelector)
{
}
// The default translation of GroupBy is:
// SELECT d as Key, (SELECT VALUE g FROM source WHERE source.Key = d) as Group
// FROM (SELECT DISTINCT source.Key)
//
// The optimized translation is simply creating a Cqt GroupByExpression
internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod)
{
// translate source
CqtExpression source = parent.TranslateSet(call.Arguments[0]);
// translate key selector
LambdaExpression keySelectorLinq = parent.GetLambdaExpression(call, 1);
DbExpressionBinding sourceBinding;
CqtExpression keySelector = parent.TranslateLambda(keySelectorLinq, source, out sourceBinding);
// translate the key selector again in a different binding context (for the nested select)
DbExpressionBinding nestedSourceBinding;
CqtExpression nestedSelector = parent.TranslateLambda(keySelectorLinq, source, out nestedSourceBinding);
// create distinct expression
if (!TypeSemantics.IsEqualComparable(keySelector.ResultType))
{
// to avoid confusing error message about the "distinct" type, pre-emptively raise an exception
// about the group by key selector
throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
}
CqtExpression distinct = parent.Distinct(
parent._commandTree.CreateProjectExpression(sourceBinding, keySelector));
DbExpressionBinding distinctBinding = parent._commandTree.CreateExpressionBinding(distinct);
// create group projection term
DbFilterExpression groupKeyFilter = parent.Filter(
nestedSourceBinding, parent.CreateEqualsExpression(nestedSelector, distinctBinding.Variable, EqualsPattern.PositiveNullEquality, keySelectorLinq.Type, keySelectorLinq.Type));
// interpret element selector if needed
CqtExpression selection = groupKeyFilter;
bool hasElementSelector = sequenceMethod == SequenceMethod.GroupByElementSelector ||
sequenceMethod == SequenceMethod.GroupByElementSelectorResultSelector;
if (hasElementSelector)
{
LambdaExpression elementSelectorLinq = parent.GetLambdaExpression(call, 2);
DbExpressionBinding elementSelectorSourceBinding;
CqtExpression elementSelector = parent.TranslateLambda(elementSelectorLinq, selection, out elementSelectorSourceBinding);
selection = parent._commandTree.CreateProjectExpression(elementSelectorSourceBinding,
elementSelector);
}
// create top level projection
List projectionTerms = new List(2);
projectionTerms.Add(distinctBinding.Variable);
projectionTerms.Add(selection);
// build projection type with initializer information
List properties = new List(2);
properties.Add(new EdmProperty(KeyColumnName, projectionTerms[0].ResultType));
properties.Add(new EdmProperty(GroupColumnName, projectionTerms[1].ResultType));
InitializerMetadata initializerMetadata = InitializerMetadata.CreateGroupingInitializer(
parent.EdmItemCollection, TypeSystem.GetElementType(call.Type));
RowType rowType = new RowType(properties, initializerMetadata);
TypeUsage rowTypeUsage = TypeUsage.Create(rowType);
CqtExpression topLevelProject = parent._commandTree.CreateProjectExpression(distinctBinding,
parent._commandTree.CreateNewInstanceExpression(rowTypeUsage, projectionTerms));
if (!hasElementSelector)
{
//Create optimized translation for the GroupBy - simple GroupBy template
DbGroupExpressionBinding groupByBinding;
CqtExpression newKeySelector = parent.TranslateLambda(keySelectorLinq, source, out groupByBinding);
DbGroupByTemplate groupByTemplate = new DbGroupByTemplate(groupByBinding);
groupByTemplate.GroupKeys.Add(new KeyValuePair(KeyColumnName, newKeySelector));
parent._groupByDefaultToOptimizedTranslationMap.Add(topLevelProject, groupByTemplate);
}
var result = topLevelProject;
// GroupBy may include a result selector; handle it
result = ProcessResultSelector(parent, call, sequenceMethod, topLevelProject, result);
return result;
}
/// <summary>
/// Applies the optional result selector of a GroupBy call on top of the projection that
/// produces (Key, Group) rows. Only the <c>GroupByResultSelector</c> and
/// <c>GroupByElementSelectorResultSelector</c> overloads carry a result selector.
/// </summary>
/// <param name="parent">Converter owning the command tree, variable map and binding context.</param>
/// <param name="call">The LINQ GroupBy method call being translated.</param>
/// <param name="sequenceMethod">Identifies which GroupBy overload <paramref name="call"/> is.</param>
/// <param name="topLevelProject">Projection whose rows expose the key and group columns.</param>
/// <param name="result">Returned unchanged when the overload has no result selector.</param>
/// <returns>
/// <paramref name="result"/> when there is no result selector. Otherwise, the expression produced
/// by <c>TryRewrite</c> if it succeeds, or else a projection of the translated selector over
/// <paramref name="topLevelProject"/>.
/// </returns>
private static DbExpression ProcessResultSelector(ExpressionConverter parent, MethodCallExpression call, SequenceMethod sequenceMethod, CqtExpression topLevelProject, DbExpression result)
{
    // The selector's argument position depends on whether an element selector precedes it.
    LambdaExpression selectorLambda;
    switch (sequenceMethod)
    {
        case SequenceMethod.GroupByResultSelector:
            selectorLambda = parent.GetLambdaExpression(call, 2);
            break;
        case SequenceMethod.GroupByElementSelectorResultSelector:
            selectorLambda = parent.GetLambdaExpression(call, 3);
            break;
        default:
            selectorLambda = null;
            break;
    }

    if (selectorLambda == null)
    {
        return result;
    }

    // Bind the (Key, Group) projection and register its variable name with the converter.
    DbExpressionBinding projectBinding = parent._commandTree.CreateExpressionBinding(topLevelProject);
    parent._variableNameToInputExpression.Add(projectBinding.VariableName, topLevelProject);
    DbPropertyExpression keyProperty = parent._commandTree.CreatePropertyExpression(
        KeyColumnName, projectBinding.Variable);
    DbPropertyExpression groupProperty = parent._commandTree.CreatePropertyExpression(
        GroupColumnName, projectBinding.Variable);

    // The selector has the shape (Key, Group) => Result; map its two parameters onto the columns.
    parent._bindingContext.PushBindingScope(
        new Binding(selectorLambda.Parameters[0], keyProperty),
        new Binding(selectorLambda.Parameters[1], groupProperty));

    CqtExpression translatedSelector = parent.TranslateExpression(selectorLambda.Body);
    DbExpression projected = parent._commandTree.CreateProjectExpression(projectBinding, translatedSelector);

    // Prefer a rewritten form of the selector when the converter can produce one.
    CqtExpression rewritten;
    bool wasRewritten = parent.TryRewrite(topLevelProject, projectBinding, translatedSelector, out rewritten);

    parent._bindingContext.PopBindingScope();
    return wasRewritten ? rewritten : projected;
}
/// <summary>
/// Overload that is treated as unreachable for this translator. Its result-selector handling
/// needs the identified <see cref="SequenceMethod"/>, so it is expected to be invoked through
/// the overload that receives it.
/// </summary>
/// <returns>Always <c>null</c>. Debug builds also raise an assertion failure.</returns>
internal override DbExpression Translate(ExpressionConverter parent, MethodCallExpression call)
{
    DbExpression noTranslation = null;
    Debug.Fail("unreachable code");
    return noTranslation;
}
}
/// <summary>
/// Translates <c>outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector)</c>.
/// </summary>
/// <remarks>
/// Shape of the produced command tree:
/// <code>
/// SELECT resultSelector(o, i)
/// FROM (
///     SELECT o, (SELECT i FROM inner AS i WHERE outerKeySelector(o) = innerKeySelector(i)) AS i
///     FROM outer AS o)
/// </code>
/// Key equality is built with <c>EqualsPattern.PositiveNullEquality</c>.
/// </remarks>
private sealed class GroupJoinTranslator : SequenceMethodTranslator
{
    internal GroupJoinTranslator()
        : base(SequenceMethod.GroupJoin)
    {
    }

    /// <summary>
    /// Builds the (o, nested collection) projection for a GroupJoin call and applies its
    /// result selector to it.
    /// </summary>
    /// <param name="parent">Converter owning the command tree and binding context.</param>
    /// <param name="call">
    /// The GroupJoin call. Its arguments are read as: 0 = outer source, 1 = inner source,
    /// 2 = outer key selector, 3 = inner key selector, 4 = result selector.
    /// </param>
    /// <returns>A projection of the translated result selector over the (o, i) rows.</returns>
    /// <remarks>
    /// Throws the exception built by <c>EntityUtil.NotSupported</c> (message
    /// <c>ELinq_UnsupportedKeySelector</c>) when either key selector's result type is not
    /// equality-comparable.
    /// </remarks>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        // Translate the outer and inner inputs.
        CqtExpression outer = parent.TranslateSet(call.Arguments[0]);
        CqtExpression inner = parent.TranslateSet(call.Arguments[1]);

        // Translate each key selector over its own input binding.
        DbExpressionBinding outerBinding;
        DbExpressionBinding innerBinding;
        LambdaExpression outerLambda = parent.GetLambdaExpression(call, 2);
        LambdaExpression innerLambda = parent.GetLambdaExpression(call, 3);
        CqtExpression outerSelector = parent.TranslateLambda(
            outerLambda, outer, out outerBinding);
        CqtExpression innerSelector = parent.TranslateLambda(
            innerLambda, inner, out innerBinding);

        // The keys are compared for equality, so both sides must be equality-comparable.
        if (!TypeSemantics.IsEqualComparable(outerSelector.ResultType) ||
            !TypeSemantics.IsEqualComparable(innerSelector.ResultType))
        {
            throw EntityUtil.NotSupported(System.Data.Entity.Strings.ELinq_UnsupportedKeySelector(call.Method.Name));
        }

        // Innermost query: SELECT i FROM inner AS i WHERE outerKey = innerKey.
        // The outer key refers to outerBinding's variable, and this filter is nested
        // inside the projection over outerBinding built below.
        CqtExpression nestedCollection = parent.Filter(innerBinding,
            parent.CreateEqualsExpression(outerSelector, innerSelector, EqualsPattern.PositiveNullEquality, outerLambda.Body.Type, innerLambda.Body.Type));

        // "Join" projection: SELECT o, (nestedCollection) AS i FROM outer AS o.
        const string outerColumn = "o";
        const string innerColumn = "i";
        List<KeyValuePair<string, CqtExpression>> recordColumns = new List<KeyValuePair<string, CqtExpression>>(2);
        recordColumns.Add(new KeyValuePair<string, CqtExpression>(outerColumn, outerBinding.Variable));
        recordColumns.Add(new KeyValuePair<string, CqtExpression>(innerColumn, nestedCollection));
        CqtExpression joinProjection = parent._commandTree.CreateNewRowExpression(recordColumns);
        CqtExpression joinProject = parent._commandTree.CreateProjectExpression(outerBinding, joinProjection);
        DbExpressionBinding joinProjectBinding = parent._commandTree.CreateExpressionBinding(joinProject);

        // Property accessors for the two row columns; these are what the result
        // selector's parameters will resolve to.
        CqtExpression outerProperty = parent._commandTree.CreatePropertyExpression(outerColumn,
            joinProjectBinding.Variable);
        CqtExpression innerProperty = parent._commandTree.CreatePropertyExpression(innerColumn,
            joinProjectBinding.Variable);

        // The result selector has the shape (o, i) => result: bind its parameters to the columns.
        LambdaExpression linqSelector = parent.GetLambdaExpression(call, 4);
        parent._bindingContext.PushBindingScope(
            new Binding(linqSelector.Parameters[0], outerProperty),
            new Binding(linqSelector.Parameters[1], innerProperty));

        CqtExpression selectorProject = parent.TranslateExpression(linqSelector.Body);

        parent._bindingContext.PopBindingScope();

        // Final projection of the translated result selector over the join rows.
        CqtExpression selector = parent._commandTree.CreateProjectExpression(joinProjectBinding, selectorProject);
        return selector;
    }
}
/// <summary>
/// Shared implementation of <c>OrderBy</c> and <c>OrderByDescending</c>: sorts the source by
/// the single translated key selector, in the direction fixed at construction.
/// </summary>
private abstract class OrderByTranslatorBase : OneLambdaTranslator
{
    // Sort direction applied to the key produced by the lambda.
    private readonly bool _ascending;

    /// <param name="ascending"><c>true</c> for an ascending sort, <c>false</c> for descending.</param>
    /// <param name="methods">The sequence methods this translator handles.</param>
    protected OrderByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Builds a sort over <paramref name="sourceBinding"/> with the translated key as its only sort clause.
    /// </summary>
    /// <param name="parent">Converter owning the command tree.</param>
    /// <param name="sourceBinding">Binding over the translated source sequence.</param>
    /// <param name="lambda">The translated key selector body.</param>
    /// <returns>The sort expression returned by <c>parent.Sort</c>.</returns>
    protected override CqtExpression TranslateOneLambda(ExpressionConverter parent, DbExpressionBinding sourceBinding, CqtExpression lambda)
    {
        List<DbSortClause> keys = new List<DbSortClause>(1);
        DbSortClause sortSpec = parent._commandTree.CreateSortClause(lambda, _ascending);
        keys.Add(sortSpec);
        DbSortExpression sort = parent.Sort(sourceBinding, keys);
        return sort;
    }
}
/// <summary>Translates <c>OrderBy</c> as an ascending single-key sort.</summary>
private sealed class OrderByTranslator : OrderByTranslatorBase
{
    internal OrderByTranslator()
        : base(true /* ascending */, SequenceMethod.OrderBy)
    {
    }
}
/// <summary>Translates <c>OrderByDescending</c> as a descending single-key sort.</summary>
private sealed class OrderByDescendingTranslator : OrderByTranslatorBase
{
    internal OrderByDescendingTranslator()
        : base(false /* ascending */, SequenceMethod.OrderByDescending)
    {
    }
}
/// <summary>
/// Shared implementation of <c>ThenBy</c> and <c>ThenByDescending</c>. It appends one sort key
/// to the sort expression produced by the preceding OrderBy/ThenBy call.
/// </summary>
/// <remarks>
/// The new key must be bound to the existing sort's input ("pushed down" below the sort)
/// rather than to the sort's output. For that reason this class does not inherit from
/// <c>OneLambdaTranslator</c>, although it is similar.
/// </remarks>
private abstract class ThenByTranslatorBase : SequenceMethodTranslator
{
    // Sort direction applied to the appended key.
    private readonly bool _ascending;

    /// <param name="ascending"><c>true</c> for an ascending key, <c>false</c> for descending.</param>
    /// <param name="methods">The sequence methods this translator handles.</param>
    protected ThenByTranslatorBase(bool ascending, params SequenceMethod[] methods)
        : base(methods)
    {
        _ascending = ascending;
    }

    /// <summary>
    /// Translates the source, verifies that it is a sort, and rebuilds that sort with the new
    /// key appended after the existing sort clauses.
    /// </summary>
    /// <param name="parent">Converter owning the command tree and binding context.</param>
    /// <param name="call">The ThenBy call: argument 0 is the source, argument 1 the key selector.</param>
    /// <returns>A sort over the original input binding with the extended key list.</returns>
    /// <remarks>
    /// Throws the exception built by <c>EntityUtil.InvalidOperation</c> (message
    /// <c>ELinq_ThenByDoesNotFollowOrderBy</c>) when the translated source is not a sort expression.
    /// </remarks>
    internal override CqtExpression Translate(ExpressionConverter parent, MethodCallExpression call)
    {
        Debug.Assert(2 == call.Arguments.Count);
        CqtExpression source = parent.TranslateSet(call.Arguments[0]);
        if (DbExpressionKind.Sort != source.ExpressionKind)
        {
            throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.ELinq_ThenByDoesNotFollowOrderBy);
        }
        DbSortExpression sortExpression = (DbSortExpression)source;

        // Reuse the existing sort's input binding so the new key is evaluated against the
        // same elements as the existing keys.
        DbExpressionBinding binding = sortExpression.Input;

        LambdaExpression lambdaExpression = parent.GetLambdaExpression(call, 1);
        ParameterExpression parameter = lambdaExpression.Parameters[0];

        // Push down: bind the lambda parameter to the sort input's variable, then translate the key.
        parent._bindingContext.PushBindingScope(new Binding(parameter, binding.Variable));
        CqtExpression lambda = parent.TranslateExpression(lambdaExpression.Body);
        parent._bindingContext.PopBindingScope();

        // Copy the existing clauses (existing keys take precedence) and append the new one.
        List<DbSortClause> keys = new List<DbSortClause>(sortExpression.SortOrder);
        keys.Add(new DbSortClause(lambda, _ascending, null));
        sortExpression = parent.Sort(binding, keys);
        return sortExpression;
    }
}
/// <summary>Translates <c>ThenBy</c> by appending an ascending key to the preceding sort.</summary>
private sealed class ThenByTranslator : ThenByTranslatorBase
{
    internal ThenByTranslator()
        : base(true /* ascending */, SequenceMethod.ThenBy)
    {
    }
}
/// <summary>Translates <c>ThenByDescending</c> by appending a descending key to the preceding sort.</summary>
private sealed class ThenByDescendingTranslator : ThenByTranslatorBase
{
    internal ThenByDescendingTranslator()
        : base(false /* ascending */, SequenceMethod.ThenByDescending)
    {
    }
}
#endregion
}
}
}
// File provided for Reference Use Only by Microsoft Corporation (c) 2007.
Link Menu

This book is available now!
Buy at Amazon US or
Buy at Amazon UK
- MediaTimeline.cs
- Utils.cs
- StreamUpdate.cs
- SqlBinder.cs
- TableCell.cs
- DirectoryInfo.cs
- ProxyFragment.cs
- InternalSafeNativeMethods.cs
- StructuredTypeInfo.cs
- KeyValuePair.cs
- XamlStyleSerializer.cs
- XmlNotation.cs
- WorkflowValidationFailedException.cs
- TraceHandlerErrorFormatter.cs
- DesignParameter.cs
- XsltArgumentList.cs
- BamlTreeMap.cs
- FixedLineResult.cs
- HostingMessageProperty.cs
- CheckPair.cs
- InvalidContentTypeException.cs
- SafeFileMappingHandle.cs
- ManagementObjectCollection.cs
- TreeNodeBindingCollection.cs
- RadioButtonPopupAdapter.cs
- ServerType.cs
- RegionIterator.cs
- CDSCollectionETWBCLProvider.cs
- RotateTransform3D.cs
- SchemaElementDecl.cs
- IgnoreDeviceFilterElementCollection.cs
- BitmapMetadata.cs
- DataObject.cs
- InstanceDataCollectionCollection.cs
- Stack.cs
- MDIClient.cs
- DatatypeImplementation.cs
- ToolStripContentPanelRenderEventArgs.cs
- HMACSHA256.cs
- TextBreakpoint.cs
- WeakReadOnlyCollection.cs
- ObjectDataSourceEventArgs.cs
- ProxyHwnd.cs
- NotifyParentPropertyAttribute.cs
- GenericEnumConverter.cs
- CreateParams.cs
- StreamedWorkflowDefinitionContext.cs
- HistoryEventArgs.cs
- CodeMemberEvent.cs
- OleDbTransaction.cs
- FileChangesMonitor.cs
- WmfPlaceableFileHeader.cs
- CopyNodeSetAction.cs
- InertiaTranslationBehavior.cs
- UnicastIPAddressInformationCollection.cs
- CustomSignedXml.cs
- CategoryNameCollection.cs
- QueueNameHelper.cs
- CurrencyWrapper.cs
- Region.cs
- BaseContextMenu.cs
- Pair.cs
- SelectionListComponentEditor.cs
- XamlTreeBuilder.cs
- Variant.cs
- EndpointBehaviorElementCollection.cs
- AssemblyHash.cs
- DataObjectFieldAttribute.cs
- OleAutBinder.cs
- NativeMethods.cs
- BamlResourceContent.cs
- XmlAttributeOverrides.cs
- XamlPointCollectionSerializer.cs
- ObjectDataSourceSelectingEventArgs.cs
- MobileUserControl.cs
- Optimizer.cs
- PropertyPathConverter.cs
- Scripts.cs
- XPathQueryGenerator.cs
- StringUtil.cs
- AssociationSetEnd.cs
- GrammarBuilderDictation.cs
- WmlPanelAdapter.cs
- HttpModulesSection.cs
- ErrorLog.cs
- CodeExpressionStatement.cs
- XamlBrushSerializer.cs
- Container.cs
- ScrollBar.cs
- TextTreeFixupNode.cs
- XmlQueryCardinality.cs
- MobileControlBuilder.cs
- PropertyMapper.cs
- ObjectTag.cs
- Context.cs
- TypedServiceChannelBuilder.cs
- MetadataCacheItem.cs
- RegexInterpreter.cs
- StorageBasedPackageProperties.cs
- EncodingDataItem.cs