using System;
using System.Reflection;
using System.Xml;
using System.Collections;
using System.Collections.Specialized;
using System.Diagnostics;
using AnticipatingMinds.Genesis.CodeDOM;
using AnticipatingMinds.CommonUIControls;
using AnticipatingMinds.Genesis.CodeDOM.Utilities;
using AnticipatingMinds.Genesis.KnowledgeManagement;
using AnticipatingMinds.Genesis.Effectors;
namespace AnticipatingMinds.KnowledgePack.Threading{
/// <summary>
/// Rule template that produces <see cref="ProtectStaticDataRule"/> instances and their
/// property-page UI. The knowledge-pack license is validated when the template is constructed.
/// </summary>
public class ProtectStaticDataRuleTemplate : RuleTemplate
{
    /// <summary>
    /// Creates the template and validates the knowledge-pack license.
    /// </summary>
    /// <param name="templateId">Identifier of this template.</param>
    /// <param name="templateProperties">Template properties forwarded to the base template.</param>
    public ProtectStaticDataRuleTemplate(string templateId, IDictionary templateProperties)
        : base(templateId, templateProperties)
    {
        KnowledgePackLicense.Validate();
    }

    /// <summary>Localized display name of the template.</summary>
    public override string Name
    {
        get
        {
            string resourceKey = "ProtectStaticDataRuleTemplate|Name";
            return ResourceManager.GetLocalizedString(resourceKey);
        }
    }

    /// <summary>Localized description of the template.</summary>
    public override string Description
    {
        get
        {
            string resourceKey = "ProtectStaticDataRuleTemplate|Description";
            return ResourceManager.GetLocalizedString(resourceKey);
        }
    }

    /// <summary>
    /// Creates a <see cref="ProtectStaticDataRule"/> owned by this template.
    /// </summary>
    /// <param name="ruleId">Identifier of the new rule.</param>
    /// <param name="ruleProperties">Properties of the new rule.</param>
    /// <returns>The new rule.</returns>
    public override Rule CreateRule(string ruleId, IDictionary ruleProperties)
    {
        // TODO: select among several rule types based on the XML properties once more exist.
        ProtectStaticDataRule createdRule = new ProtectStaticDataRule(ruleId, Id, ruleProperties);
        return createdRule;
    }

    /// <summary>
    /// Returns the property pages used to edit a rule's properties: a single ProtectedStaticDataUI page.
    /// </summary>
    /// <param name="ruleProperties">Properties of the rule being edited.</param>
    /// <returns>A one-element array of property pages.</returns>
    public override AnticipatingMinds.CommonUIControls.PropertyPage[] GetRulePropertiesUI(IDictionary ruleProperties)
    {
        ProtectedStaticDataUI propertiesPage = new ProtectedStaticDataUI(ruleProperties);
        return new AnticipatingMinds.CommonUIControls.PropertyPage[] { propertiesPage };
    }
}
class ProtectStaticDataRule : Rule
{
#region Rule Property
/// <summary>
/// Property-name keys understood by this rule, extending the base rule's keys.
/// </summary>
public new class PropertyName : Rule.PropertyName
{
    /// <summary>
    /// Key of the user-defined list of type names whose static fields are not checked.
    /// </summary>
    public const string DoNotEnforceTypeList = "DoNotEnforceTypeList";
}
/// <summary>
/// User-defined list of type names. Static fields declared with one of these types are
/// skipped by the analysis. Reading returns an empty list when the property is not set.
/// </summary>
public ArrayList DoNotEnforceTypeList
{
    get
    {
        ArrayList emptyDefault = new ArrayList();
        return GetProperty(PropertyName.DoNotEnforceTypeList, emptyDefault) as ArrayList;
    }
    set
    {
        SetProperty(PropertyName.DoNotEnforceTypeList, value);
    }
}
#endregion
/// <summary>
/// Creates the rule; all construction work is delegated to the base <see cref="Rule"/>.
/// </summary>
/// <param name="id">Rule identifier.</param>
/// <param name="templateId">Identifier of the owning template.</param>
/// <param name="ruleProperties">Rule properties.</param>
internal ProtectStaticDataRule(string id, string templateId, IDictionary ruleProperties)
    : base(id, templateId, ruleProperties)
{
    // No rule-specific initialization.
}
/// <summary>
/// Mutable state shared by the analysis steps while one class declaration is analyzed.
/// </summary>
class RuleData
{
    /// <summary>Class declaration being analyzed.</summary>
    public CodeClassDeclaration AnalyzedType;
    /// <summary>Static fields of the analyzed type: key is the field name, value is the (normalized) type name.</summary>
    public IDictionary StaticFields;
    /// <summary>
    /// Names of static fields whose value is, or may be, changed: assigned, incremented or decremented,
    /// passed by ref/out, or used as the target of a method call. (The name's spelling is kept for callers.)
    /// </summary>
    public StringCollection ModifiedStaticFileds;
    /// <summary>
    /// Each entry is a CodeElementCollection holding the references to static fields found inside
    /// one outermost statement and everything nested in it.
    /// </summary>
    public ArrayList BlocksOfReferencesToStaticFields;
    /// <summary>Stack of reference collections, one entry per statement currently being walked.</summary>
    public Stack StatementsScope;

    /// <summary>Creates empty containers for all collected data.</summary>
    public RuleData()
    {
        StaticFields = new Hashtable();
        ModifiedStaticFileds = new StringCollection();
        BlocksOfReferencesToStaticFields = new ArrayList();
        StatementsScope = new Stack();
    }
}
/// <summary>
/// Keys of the per-reference usage information stored in CodeElement.ApplicationData.
/// </summary>
class UsageConstants
{
    // Lock-related usage (written by DiscoverStaticLockUsage).
    /// <summary>bool: at least one static lock is held at the reference.</summary>
    public const string IsLocked = "IsLocked";
    /// <summary>bool: exactly one static lock is held and the lock statement guards only this statement.</summary>
    public const string IsLockedAlone = "IsLockedAlone";
    /// <summary>CodeExpressionCollection of the static lock expressions held at the reference.</summary>
    public const string LockExpression = "LockExpression";

    // Arithmetic usage (written by DiscoverIncDecUsage / DiscoverInterlockedUsage).
    /// <summary>bool: the reference is incremented by one.</summary>
    public const string IsIncrement = "IsIncrement";
    /// <summary>bool: the reference is decremented by one.</summary>
    public const string IsDecrement = "IsDecrement";
    /// <summary>bool: the reference is used inside an Interlocked.Increment/Decrement call.</summary>
    public const string IsInterlocked = "IsInterlocked";
}
/// <summary>
/// Values stored under ViolationData["ViolationType"] to tell the kinds of violation apart.
/// </summary>
class ViolationType
{
    /// <summary>An Interlocked operation is additionally wrapped in a lock; the lock should not be used.</summary>
    public const string LockedInterlockedOperation = "LockedInterlockedOperation";
    /// <summary>A locked increment or decrement that should use Interlocked instead.</summary>
    public const string LockedIncrementOrDecrement = "LockedIncrementOrDecrement";
    /// <summary>An unprotected reference that should be protected with Interlocked.</summary>
    public const string InterlockableUnprotectedReference = "InterlockableUnprotectedReference";
    /// <summary>An unprotected reference that should be protected with a lock.</summary>
    public const string LockableUnprotectedReference = "LockableUnprotectedReference";
    /// <summary>An unprotected block of references that should be protected with a lock.</summary>
    public const string LockableUnprotectedReferencesBlock = "LockableUnprotectedReferencesBlock";
}
/// <summary>
/// Records every static field of <paramref name="classDeclaration"/> in
/// <c>ruleData.StaticFields</c> (key: field name, value: field type name).
/// Int32, Int64 and Double fields are stored as "int", "long" and "double" respectively,
/// whether they are written as the C# keyword, the short CLR name or the full CLR name
/// (the declared name is checked first, then the resolved full name). Any other field
/// stores its resolved full type name when type information is available, otherwise
/// its declared type name. Fields whose type is listed in DoNotEnforceTypeList are skipped.
/// </summary>
/// <param name="classDeclaration">Class whose members are scanned.</param>
/// <param name="ruleData">Analysis state receiving the static fields.</param>
private void InitializeRuleDataWithTypeStaticMembers(CodeClassDeclaration classDeclaration,RuleData ruleData)
{
    // The exclusion list does not change while one type is analyzed; read the property once.
    ArrayList doNotEnforceTypes = DoNotEnforceTypeList;
    foreach(CodeTypeMemberDeclaration typeMember in classDeclaration.Members)
    {
        CodeTypeFieldDeclaration fieldDeclaration = typeMember as CodeTypeFieldDeclaration;
        if(fieldDeclaration == null)
            continue;
        if((typeMember.Modifiers & CodeTypeMemberDeclaration.MemberDeclarationModifiers.Static) == 0)
            continue;

        string declaredTypeName = fieldDeclaration.DeclarationType.TypeName;
        string resolvedTypeName = fieldDeclaration.DeclarationType.TypeInfo == null
            ? declaredTypeName
            : fieldDeclaration.DeclarationType.TypeInfo.FullName;

        // Types the user declared thread-safe are not enforced.
        if(doNotEnforceTypes != null && doNotEnforceTypes.Contains(resolvedTypeName))
            continue;

        // Downstream checks compare against "int"/"long"/"double", so every spelling
        // (keyword, short name, full name, or an alias that resolves to one) is normalized.
        string storedTypeName = GetNormalizedNumericTypeName(declaredTypeName);
        if(storedTypeName == null)
            storedTypeName = GetNormalizedNumericTypeName(resolvedTypeName);
        if(storedTypeName == null)
            storedTypeName = resolvedTypeName;

        foreach(CodeVariableDeclarationMember varDeclaration in fieldDeclaration.DeclaredFields)
        {
            // Indexer instead of Add: duplicate field names (possible in code that does not
            // compile yet) must not abort the whole analysis with an ArgumentException.
            ruleData.StaticFields[varDeclaration.Name] = storedTypeName;
        }
    }
}

/// <summary>
/// Maps the spellings of Int32, Int64 and Double to "int", "long" and "double".
/// </summary>
/// <param name="typeName">Type name as written or resolved; may be null.</param>
/// <returns>The normalized name, or null when <paramref name="typeName"/> is not one of these types.</returns>
private static string GetNormalizedNumericTypeName(string typeName)
{
    switch(typeName)
    {
        case "int":
        case "Int32":
        case "System.Int32":
            return "int";
        case "long":
        case "Int64":
        case "System.Int64":
            return "long";
        case "double":
        case "Double":
        case "System.Double":
            return "double";
        default:
            return null;
    }
}
/// <summary>
/// CodeDomWalker callback that gathers references to the analyzed type's static fields into
/// RuleData.BlocksOfReferencesToStaticFields.
/// Members outside the rule's applicability scope and static constructors are skipped.
/// When an outermost statement starts, a new reference collection is pushed; nested statements
/// push the same collection again. When the outermost statement finishes, its collection is
/// stored as one block if it is not empty. A reference is collected when its name is a static
/// field and it is unqualified or qualified with the analyzed type's name or full name.
/// </summary>
/// <param name="codeElement">Element being visited.</param>
/// <param name="notificationType">Kind of walker notification.</param>
/// <param name="walkerContext">Walker context (unused).</param>
/// <param name="applicationData">The RuleData being filled.</param>
/// <returns>NextSibling to skip a member's subtree; Next otherwise.</returns>
private CodeDomWalker.WalkerCallbackReturn GetReferencesToStaticFields(CodeElement codeElement, CodeDomWalker.CallBackNotificationType notificationType, CodeDomWalkerContext walkerContext,object applicationData)
{
    RuleData ruleData = applicationData as RuleData;

    if(notificationType == CodeDomWalker.CallBackNotificationType.OnElement)
    {
        CodeTypeMemberDeclaration memberDeclaration = codeElement as CodeTypeMemberDeclaration;
        if(memberDeclaration != null && !IsCodeElementInRuleApplicabilityScope(memberDeclaration))
            return CodeDomWalker.WalkerCallbackReturn.NextSibling;

        // Static constructors are not analyzed.
        CodeTypeConstructorDeclaration constructorDeclaration = codeElement as CodeTypeConstructorDeclaration;
        if(constructorDeclaration != null &&
            (constructorDeclaration.Modifiers & CodeTypeMemberDeclaration.MemberDeclarationModifiers.Static) != 0)
        {
            return CodeDomWalker.WalkerCallbackReturn.NextSibling;
        }

        CodeNamedReferenceExpression namedReference = codeElement as CodeNamedReferenceExpression;
        if(namedReference != null && ruleData.StaticFields.Contains(namedReference.Name))
        {
            string qualifier = "";
            CodeNamedReferenceExpression namedTarget = namedReference.TargetObject as CodeNamedReferenceExpression;
            if(namedTarget != null)
                qualifier = CodeNamedReferenceExpressionUtils.GetTargetReferencesAsString(namedTarget);
            CodeTypeReferenceExpression typeTarget = namedReference.TargetObject as CodeTypeReferenceExpression;
            if(typeTarget != null)
                qualifier = typeTarget.ReferencedType.TypeName;

            bool refersToAnalyzedType = qualifier.Length == 0 ||
                qualifier == ruleData.AnalyzedType.FullName ||
                qualifier == ruleData.AnalyzedType.Name;

            // An empty scope stack means the reference is not inside any statement (the original
            // author notes this happens for nested constructors that accept static variables);
            // nothing can be done about such references, so they are ignored.
            if(refersToAnalyzedType && ruleData.StatementsScope.Count != 0)
                (ruleData.StatementsScope.Peek() as CodeElementCollection).Add(codeElement);
        }
        return CodeDomWalker.WalkerCallbackReturn.Next;
    }

    if(!(codeElement is CodeStatement))
        return CodeDomWalker.WalkerCallbackReturn.Next;

    if(notificationType == CodeDomWalker.CallBackNotificationType.OnElementChildrenStarted)
    {
        // Outermost statement (member body, getter or setter): start a fresh collection.
        // Nested statement: keep adding to the enclosing statement's collection.
        if(ruleData.StatementsScope.Count == 0)
            ruleData.StatementsScope.Push(new CodeElementCollection());
        else
            ruleData.StatementsScope.Push(ruleData.StatementsScope.Peek());
    }
    else if(notificationType == CodeDomWalker.CallBackNotificationType.OnElementChildrenFinished)
    {
        CodeElementCollection finishedScope = ruleData.StatementsScope.Pop() as CodeElementCollection;
        // Stack empty again: the outermost statement is done, so its references form one block.
        if(ruleData.StatementsScope.Count == 0 && finishedScope.Count != 0)
            ruleData.BlocksOfReferencesToStaticFields.Add(finishedScope);
    }
    return CodeDomWalker.WalkerCallbackReturn.Next;
}
/// <summary>
/// Returns true if the field, variable or object referenced by <paramref name="codeElement"/>
/// may be updated by the code that encloses it: an assignment (simple or compound) whose
/// left operand contains the element, a ref/out argument, a prefix/postfix increment or
/// decrement, or a method invoked directly on the referenced object.
/// </summary>
/// <param name="codeElement">Reference expression to test.</param>
/// <returns>true if some enclosing expression may change the referenced value; otherwise false.</returns>
public bool IsElementValueUpdated(CodeElement codeElement)
{
    CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(codeElement).Trace;
    // Four kinds of enclosing code are treated as a value update:
    // 1) an assignment whose left operand contains the element;
    // 2) a call that passes the element (or an expression containing it) as ref or out;
    // 3) a unary increment or decrement;
    // 4) a method invoked on the referenced object - it can change the object's state,
    //    so it counts as an update (think Hashtable.Add, ArrayList.Insert, ...).
    for(int traceFrameIndex = trace.Length - 1; traceFrameIndex >= 0; traceFrameIndex--)
    {
        CodeElement traceFrame = trace[traceFrameIndex];

        CodeMethodInvokeExpression invokeExpression = traceFrame as CodeMethodInvokeExpression;
        if(invokeExpression != null)
        {
            CodeNamedReferenceExpression methodReferenceExpression = invokeExpression.MethodReferenceExpression as CodeNamedReferenceExpression;
            if(methodReferenceExpression != null && methodReferenceExpression.TargetObject == codeElement)
                return true;
        }

        CodeArgument codeArgument = traceFrame as CodeArgument;
        if(codeArgument != null &&
            (((codeArgument.Modifier & CodeArgument.CodeArgumentModifier.Ref) != 0) ||
             ((codeArgument.Modifier & CodeArgument.CodeArgumentModifier.Out) != 0)))
        {
            return true;
        }

        CodeUnaryExpression unaryExpression = traceFrame as CodeUnaryExpression;
        if(unaryExpression != null &&
            (unaryExpression.Operator == CodeUnaryOperatorType.PostfixDecrement ||
             unaryExpression.Operator == CodeUnaryOperatorType.PostfixIncrement ||
             unaryExpression.Operator == CodeUnaryOperatorType.PrefixIncrement ||
             unaryExpression.Operator == CodeUnaryOperatorType.PrefixDecrement))
        {
            return true;
        }

        CodeBinaryExpression binaryExpression = traceFrame as CodeBinaryExpression;
        if(binaryExpression != null && IsAssigningBinaryOperator(binaryExpression.Operator))
        {
            // Only the left operand is written. The next frame down the trace tells which operand
            // leads to the element; the bounds check protects against an assignment in the last frame.
            if(traceFrameIndex + 1 < trace.Length && trace[traceFrameIndex + 1] == binaryExpression.LeftOperand)
                return true;
        }
    }
    return false;
}

/// <summary>
/// True for the simple assignment operator and every compound assignment operator.
/// </summary>
/// <param name="binaryOperator">Operator to test.</param>
/// <returns>true if the operator writes its left operand.</returns>
private static bool IsAssigningBinaryOperator(CodeBinaryOperatorType binaryOperator)
{
    return binaryOperator == CodeBinaryOperatorType.AdditionAssignment ||
        binaryOperator == CodeBinaryOperatorType.Assign ||
        binaryOperator == CodeBinaryOperatorType.BitwiseAndAssignment ||
        binaryOperator == CodeBinaryOperatorType.BitwiseOrAssignment ||
        binaryOperator == CodeBinaryOperatorType.DivisionAssignment ||
        binaryOperator == CodeBinaryOperatorType.ExclusiveOrAssignment ||
        binaryOperator == CodeBinaryOperatorType.LeftShiftAssignment ||
        binaryOperator == CodeBinaryOperatorType.ModulusAssignment ||
        binaryOperator == CodeBinaryOperatorType.MultiplicationAssignment ||
        binaryOperator == CodeBinaryOperatorType.RightShiftAssignment ||
        binaryOperator == CodeBinaryOperatorType.SubtractionAssignment;
}
/// <summary>
/// Removes references that need no protection. A reference is dropped when its field is not
/// in ruleData.ModifiedStaticFileds, unless it is the collection iterated by a foreach
/// statement (those are always kept). Blocks left empty are removed from
/// ruleData.BlocksOfReferencesToStaticFields.
/// </summary>
/// <param name="ruleData">Analysis state whose reference blocks are pruned in place.</param>
private void PruneReferencesThatDoesnotHaveToBeProtected(RuleData ruleData)
{
    ArrayList emptiedBlocks = new ArrayList();
    foreach(CodeElementCollection referencesBlock in ruleData.BlocksOfReferencesToStaticFields)
    {
        // Collect first and remove afterwards so the block is not modified while it is enumerated.
        ArrayList removableReferences = new ArrayList();
        foreach(CodeNamedReferenceExpression reference in referencesBlock)
        {
            // References to fields that are modified somewhere must stay.
            if(ruleData.ModifiedStaticFileds.Contains(reference.Name))
                continue;
            // The collection iterated by a foreach always has to be protected.
            CodeForEachStatement parentForEach = reference.Parent as CodeForEachStatement;
            if(parentForEach != null && parentForEach.IteratedCollection == reference)
                continue;
            removableReferences.Add(reference);
        }

        foreach(CodeElement removableReference in removableReferences)
            referencesBlock.Remove(removableReference);

        if(referencesBlock.Count == 0)
            emptiedBlocks.Add(referencesBlock);
    }

    foreach(CodeElementCollection emptiedBlock in emptiedBlocks)
        ruleData.BlocksOfReferencesToStaticFields.Remove(emptiedBlock);
}
/// <summary>
/// Fills ruleData.ModifiedStaticFileds with the name of every static field that has at least
/// one collected reference for which IsElementValueUpdated returns true.
/// </summary>
/// <param name="ruleData">Analysis state holding the collected reference blocks.</param>
private void DiscoverModifiedStaticFields(RuleData ruleData)
{
    StringCollection modifiedFields = ruleData.ModifiedStaticFileds;
    foreach(CodeElementCollection referencesBlock in ruleData.BlocksOfReferencesToStaticFields)
    {
        foreach(CodeNamedReferenceExpression referenceExpression in referencesBlock)
        {
            // A field already known to be modified needs no further inspection.
            bool alreadyKnown = modifiedFields.Contains(referenceExpression.Name);
            if(!alreadyKnown && IsElementValueUpdated(referenceExpression))
                modifiedFields.Add(referenceExpression.Name);
        }
    }
}
/// <summary>
/// Main entry point: analyzes one class declaration and returns the violations found
/// for its static fields.
/// </summary>
/// <param name="codeElement">Class to analyze; cast directly to CodeClassDeclaration.</param>
/// <param name="cancelAnalysisEvent">Cancellation event; not consulted by this implementation.</param>
/// <returns>The violations found, or noViolations when no static field reference was collected.</returns>
public override RuleViolation[] Analyze(object codeElement, System.Threading.ManualResetEvent cancelAnalysisEvent)
{
    RuleData ruleData = new RuleData();
    ruleData.AnalyzedType = (CodeClassDeclaration)codeElement;
    InitializeRuleDataWithTypeStaticMembers(ruleData.AnalyzedType, ruleData);

    // Collect references to static fields, grouped into blocks per outermost statement.
    CodeDomWalker.WalkerCallback referenceCollector = new CodeDomWalker.WalkerCallback(GetReferencesToStaticFields);
    CodeDomWalker.WalkCodeElement(codeElement as CodeElement, referenceCollector, ruleData);

    if(ruleData.BlocksOfReferencesToStaticFields.Count != 0)
        return AnalyzeStaticReferences(ruleData).ToArray();
    return noViolations;
}
/// <summary>
/// Analyzes the collected reference blocks and gathers the resulting violations.
/// Modified fields are discovered before pruning, because pruning keeps only references
/// to modified fields (plus foreach iterated collections).
/// </summary>
/// <param name="ruleData">Analysis state holding the collected reference blocks.</param>
/// <returns>All violations reported by AnalyzeMultipleReferences for the remaining blocks.</returns>
private RuleViolationCollection AnalyzeStaticReferences(RuleData ruleData)
{
    DiscoverModifiedStaticFields(ruleData);
    PruneReferencesThatDoesnotHaveToBeProtected(ruleData);

    RuleViolationCollection allViolations = new RuleViolationCollection();
    foreach(CodeElementCollection remainingBlock in ruleData.BlocksOfReferencesToStaticFields)
    {
        allViolations.AddRange(AnalyzeMultipleReferences(remainingBlock, ruleData));
    }
    return allViolations;
}
/// <summary>
/// Looks for a MethodImpl attribute whose arguments reference <c>Synchronized</c> and, if one is
/// found, returns the expression this rule treats as the member's implicit lock:
/// typeof(DeclaringType) for a static member, a this-reference otherwise.
/// </summary>
/// <param name="typeMember">Member the attributes belong to; its Static modifier selects the lock expression.</param>
/// <param name="attributes">Attributes to inspect (member, getter or setter attributes); may be null.</param>
/// <returns>The lock expression, or null when there is no synchronizing MethodImpl attribute.</returns>
private CodeExpression GetTypeMemberSynchAttributeLockExpression(CodeTypeMemberDeclaration typeMember, CodeAttributeCollection attributes)
{
    if(attributes == null)
        return null;
    foreach(CodeAttribute attribute in attributes)
    {
        if(!IsMethodImplAttributeTypeName(attribute.AttributeType.TypeName))
            continue;
        foreach(CodeArgument argument in attribute.Arguments)
        {
            if(CodeNamedReferenceExpressionUtils.FindReferenceTo(argument.Value,"Synchronized").Count == 0)
                continue;
            if((typeMember.Modifiers & CodeTypeMemberDeclaration.MemberDeclarationModifiers.Static) != 0)
                return new CodeUnaryExpression(CodeUnaryOperatorType.TypeOf,new CodeTypeReferenceExpression(typeMember.DeclaringType.FullName));
            return new CodeThisReferenceExpression();
        }
    }
    return null;
}

/// <summary>
/// True when <paramref name="typeName"/> names MethodImplAttribute by its short or fully-qualified
/// name, with or without the "Attribute" suffix (C# lets the suffix be omitted in both forms).
/// The comparison stays case-insensitive, as it was before.
/// </summary>
/// <param name="typeName">Attribute type name as written in source; may be null.</param>
/// <returns>true if the name denotes MethodImplAttribute.</returns>
private static bool IsMethodImplAttributeTypeName(string typeName)
{
    string[] acceptedNames = new string[]
    {
        "MethodImplAttribute",
        "MethodImpl",
        "System.Runtime.CompilerServices.MethodImplAttribute",
        "System.Runtime.CompilerServices.MethodImpl"
    };
    foreach(string acceptedName in acceptedNames)
    {
        if(String.Compare(typeName, acceptedName, true) == 0)
            return true;
    }
    return false;
}
/// <summary>
/// Decides whether a lock target is a static lock of the analyzed type and normalizes it.
/// Enclosing parentheses are removed first. The method returns:
/// - for typeof(...), the operand with its parentheses removed;
/// - for a named reference whose name is a static field, the reference itself;
/// - for FIELD.SyncRoot where FIELD is a static field, the property reference itself;
/// - null in every other case.
/// </summary>
/// <param name="lockExpression">Expression used as the lock target.</param>
/// <param name="ruleData">Analysis state providing the static field names.</param>
/// <returns>The normalized static lock expression, or null.</returns>
private CodeExpression GetStaticLockExpression(CodeExpression lockExpression,RuleData ruleData)
{
    CodeExpression target = lockExpression;
    while(target is CodeParenthesizedExpression)
        target = ((CodeParenthesizedExpression)target).Expression;

    CodeUnaryExpression unaryExpression = target as CodeUnaryExpression;
    if(unaryExpression != null && unaryExpression.Operator == CodeUnaryOperatorType.TypeOf)
    {
        CodeExpression typeOperand = unaryExpression.Operand;
        while(typeOperand is CodeParenthesizedExpression)
            typeOperand = ((CodeParenthesizedExpression)typeOperand).Expression;
        return typeOperand;
    }

    CodeNamedReferenceExpression namedReference = target as CodeNamedReferenceExpression;
    if(namedReference != null && ruleData.StaticFields.Contains(namedReference.Name))
        return target;

    CodePropertyReferenceExpression propertyReference = target as CodePropertyReferenceExpression;
    if(propertyReference == null)
        return null;
    CodeNamedReferenceExpression syncRootOwner = propertyReference.TargetObject as CodeNamedReferenceExpression;
    if(syncRootOwner == null)
        return null;

    bool isStaticSyncRoot = propertyReference.PropertyInfo != null
        && propertyReference.PropertyInfo.Name == "SyncRoot"
        && ruleData.StaticFields.Contains(syncRootOwner.Name);
    return isStaticSyncRoot ? target : null;
}
/// <summary>
/// Determines which static locks protect <paramref name="codeElement"/> and stores the result
/// in its ApplicationData under these usage constants:
/// UsageConstants.IsLocked - true when at least one static lock is held at the reference;
/// UsageConstants.LockExpression - the CodeExpressionCollection of static lock expressions held;
/// UsageConstants.IsLockedAlone - true when exactly one static lock is held and the innermost
/// lock statement on the trace guards nothing but the statement containing the element.
/// Locks are gathered from [MethodImpl(Synchronized)] on static methods/properties/accessors,
/// from lock statements, and from Monitor.Enter calls that precede the element in enclosing
/// statement collections (a matching earlier Monitor.Exit removes the lock again).
/// </summary>
/// <param name="codeElement">Reference to a static field.</param>
/// <param name="trace">Trace of codeElement. Presumably ordered outermost frame first - the loop
/// treats trace[i-1] as the container of trace[i] and trace[i+1] as its child; TODO confirm.</param>
/// <param name="ruleData">Analysis state; its StaticFields decide which lock targets count as static.</param>
private void DiscoverStaticLockUsage(CodeElement codeElement,CodeElement[] trace,RuleData ruleData)
{
//Initialize associated values first
codeElement.ApplicationData[UsageConstants.IsLocked] = false;
codeElement.ApplicationData[UsageConstants.LockExpression] = null;
codeElement.ApplicationData[UsageConstants.IsLockedAlone] = false;
//Walk the trace; for statement frames, get the containing statement collection and
//see whether a lock was acquired (and not released) before reaching the element
CodeExpressionCollection aquiredLocks = new CodeExpressionCollection();
CodeLockStatement lockStatement = null;
for(int traceFrameIndex = 0; traceFrameIndex < trace.Length; traceFrameIndex++)
{
CodeElement traceFrame = trace[traceFrameIndex];
if(traceFrame is CodeTypeMethodDeclaration)
{
CodeTypeMethodDeclaration methodDeclaration = traceFrame as CodeTypeMethodDeclaration;
CodeExpression attributedLockExpression = GetTypeMemberSynchAttributeLockExpression(methodDeclaration,methodDeclaration.Attributes);
//Instance-level synchronization (a 'this' lock) is ignored; only static locks are tracked
if(attributedLockExpression != null && !(attributedLockExpression is CodeThisReferenceExpression))
aquiredLocks.Add(attributedLockExpression);
}
if(traceFrame is CodeTypePropertyDeclaration)
{
CodeTypePropertyDeclaration propertyDeclaration = traceFrame as CodeTypePropertyDeclaration;
CodeExpression attributedLockExpression = GetTypeMemberSynchAttributeLockExpression(propertyDeclaration,propertyDeclaration.Attributes);
if(attributedLockExpression != null && !(attributedLockExpression is CodeThisReferenceExpression))
aquiredLocks.Add(attributedLockExpression);
//The next frame tells whether the element is inside the getter or the setter;
//accessor-level attributes apply only to that accessor
if(propertyDeclaration.GetAccessorStatements == trace[traceFrameIndex+1])
{
attributedLockExpression = GetTypeMemberSynchAttributeLockExpression(propertyDeclaration,propertyDeclaration.GetAccessorAttributes);
if(attributedLockExpression != null && !(attributedLockExpression is CodeThisReferenceExpression))
aquiredLocks.Add(attributedLockExpression);
}
if(propertyDeclaration.SetAccessorStatements == trace[traceFrameIndex+1])
{
attributedLockExpression = GetTypeMemberSynchAttributeLockExpression(propertyDeclaration,propertyDeclaration.SetAccessorAttributes);
if(attributedLockExpression != null && !(attributedLockExpression is CodeThisReferenceExpression))
aquiredLocks.Add(attributedLockExpression);
}
}
if(traceFrame is CodeLockStatement)
{
//NOTE(review): this remembers the innermost lock statement even when its target is not a
//static lock (e.g. lock(this)); IsLockedAlone below is judged against this statement, which
//may not be the lock counted in aquiredLocks - confirm this is intended.
lockStatement = traceFrame as CodeLockStatement;
CodeExpression lockExpression = GetStaticLockExpression(lockStatement.Expression,ruleData);
if(lockExpression != null)
aquiredLocks.Add(lockExpression);
//NOTE(review): because of this 'continue', Monitor.Enter/Exit calls that precede a lock
//statement in its own containing statement collection are not scanned below - confirm intended.
continue;
}
if(traceFrame is CodeStatement)
{
CollectionBase containerCollection = null;
//trace[traceFrameIndex-1] is used as the parent of this statement; assumes a statement is
//never the first frame of the trace
CodeElementUtils.GetElementContainerInfo(trace[traceFrameIndex-1],traceFrame,out containerCollection);
if(containerCollection is CodeStatementCollection)
{
//Let's see if there were Monitor.Enter Statements in this collection
//before our element
foreach(CodeStatement statement in containerCollection)
{
if(statement == traceFrame)
break;
MonitorInvokeData lockData = GetMonitorInvokeData(statement);
if(lockData != null)
{
if(lockData.isEnter)
{
CodeExpression lockExpression = GetStaticLockExpression(lockData.lockExpression,ruleData);
if(lockExpression != null)
aquiredLocks.Add(lockExpression);
}
else
{
//Monitor.Exit: release the first acquired lock with the same signature
CodeExpression lockExpression = GetStaticLockExpression(lockData.lockExpression,ruleData);
if(lockExpression != null)
{
string lockSignature = string.Empty;
lockSignature = GetStaticLockSignature(lockExpression);
foreach(CodeExpression aquiredLockExpression in aquiredLocks)
{
if(lockSignature == GetStaticLockSignature(aquiredLockExpression))
{
//Removing while enumerating is safe here only because the loop breaks immediately
aquiredLocks.Remove(aquiredLockExpression);
break;
}
}
}
}
}
}
}
}
}
//Let's figure out if statement is locked alone
bool isLockedAlone = false;
//Only a reference protected by exactly one static lock can be "locked alone";
//with zero or several locks it never is
if(aquiredLocks.Count == 1)
{
CodeStatement elementStatement = CodeStatementUtils.GetElementStatement(codeElement);
//If locked by lock statement test if we are locked alone: after unwrapping single-statement
//blocks, the locked statement must be the statement that contains the element
if(lockStatement != null)
{
CodeStatement lockedStatement = lockStatement.Statement;
while(lockedStatement is CodeStatementBlock && (lockedStatement as CodeStatementBlock).Statements.Count == 1)
lockedStatement = (lockedStatement as CodeStatementBlock).Statements[0];
if(lockedStatement == elementStatement)
isLockedAlone = true;
}
}
codeElement.ApplicationData[UsageConstants.IsLocked] = aquiredLocks.Count != 0;
codeElement.ApplicationData[UsageConstants.LockExpression] = aquiredLocks;
codeElement.ApplicationData[UsageConstants.IsLockedAlone] = isLockedAlone;
}
/// <summary>
/// Result of scanning a statement for a System.Threading.Monitor Enter/Exit call.
/// </summary>
private class MonitorInvokeData
{
    /// <summary>true when the call found was Monitor.Enter, false for Monitor.Exit (default false).</summary>
    public bool isEnter;
    /// <summary>First argument of the call found; null (the default) when no call was found.</summary>
    public CodeExpression lockExpression;
}
/// <summary>
/// Walks <paramref name="codeStatement"/> and reports the first Monitor.Enter or Monitor.Exit
/// call the walker encounters.
/// </summary>
/// <param name="codeStatement">Statement to scan.</param>
/// <returns>The call data, or null when no such call (with a lock argument) was found.</returns>
private MonitorInvokeData GetMonitorInvokeData(CodeStatement codeStatement)
{
    MonitorInvokeData result = new MonitorInvokeData();
    CodeDomWalker.WalkerCallback monitorCallCollector = new CodeDomWalker.WalkerCallback(CollectReferencesToStatmentMonitorCalls);
    CodeDomWalker.WalkCodeElement(codeStatement, monitorCallCollector, result);
    return result.lockExpression != null ? result : null;
}
/// <summary>
/// CodeDomWalker callback used by GetMonitorInvokeData. At the first resolved call to
/// System.Threading.Monitor.Enter or Monitor.Exit it stores the call's first argument and
/// its direction in the MonitorInvokeData passed as applicationData, then cancels the walk.
/// </summary>
/// <param name="codeElement">Element being visited.</param>
/// <param name="notificationType">Kind of walker notification; only OnElement is handled.</param>
/// <param name="walkerContext">Walker context (unused).</param>
/// <param name="applicationData">MonitorInvokeData receiving the result.</param>
/// <returns>Cancel once a Monitor Enter/Exit call is found; Next otherwise.</returns>
private CodeDomWalker.WalkerCallbackReturn CollectReferencesToStatmentMonitorCalls(CodeElement codeElement, CodeDomWalker.CallBackNotificationType notificationType, CodeDomWalkerContext walkerContext,object applicationData)
{
    if(notificationType != CodeDomWalker.CallBackNotificationType.OnElement)
        return CodeDomWalker.WalkerCallbackReturn.Next;

    CodeMethodInvokeExpression methodInvokeExpression = codeElement as CodeMethodInvokeExpression;
    if(methodInvokeExpression == null)
        return CodeDomWalker.WalkerCallbackReturn.Next;

    CodeMethodReferenceExpression methodReferenceExpression = methodInvokeExpression.MethodReferenceExpression as CodeMethodReferenceExpression;
    if(methodReferenceExpression == null ||
        methodReferenceExpression.MethodInfo == null ||
        methodReferenceExpression.MethodInfo.DeclaringTypeInfo == null)
    {
        return CodeDomWalker.WalkerCallbackReturn.Next;
    }

    if(methodReferenceExpression.MethodInfo.DeclaringTypeInfo.FullName != "System.Threading.Monitor")
        return CodeDomWalker.WalkerCallbackReturn.Next;

    string calledMethod = methodReferenceExpression.MethodInfo.Name;
    bool isEnterCall = calledMethod == "Enter";
    if(!isEnterCall && calledMethod != "Exit")
        return CodeDomWalker.WalkerCallbackReturn.Next;

    // The lock object is always the first argument of Monitor.Enter/Exit.
    CodeExpression lockTarget = methodInvokeExpression.Arguments[0].Value;
    MonitorInvokeData data = applicationData as MonitorInvokeData;
    data.lockExpression = lockTarget;
    data.isEnter = isEnterCall;
    return CodeDomWalker.WalkerCallbackReturn.Cancel;
}
/// <summary>
/// Builds a string key for a lock target, used to match a Monitor.Exit with an earlier
/// acquisition. Enclosing parentheses are removed. A unary expression contributes the
/// prefix "typeof:" and is replaced by its operand. A named reference then contributes its
/// target-reference string, and a type reference its type name.
/// </summary>
/// <param name="lockExpression">Lock target expression.</param>
/// <returns>The signature; empty when the expression is of none of the handled kinds.</returns>
private string GetStaticLockSignature(CodeExpression lockExpression)
{
    CodeExpression expression = lockExpression;
    while(expression is CodeParenthesizedExpression)
        expression = ((CodeParenthesizedExpression)expression).Expression;

    string prefix = string.Empty;
    CodeUnaryExpression unaryExpression = expression as CodeUnaryExpression;
    if(unaryExpression != null)
    {
        prefix = "typeof:";
        expression = unaryExpression.Operand;
    }

    string body = string.Empty;
    CodeNamedReferenceExpression namedReference = expression as CodeNamedReferenceExpression;
    if(namedReference != null)
        body += CodeNamedReferenceExpressionUtils.GetTargetReferencesAsString(namedReference);
    CodeTypeReferenceExpression typeReference = expression as CodeTypeReferenceExpression;
    if(typeReference != null)
        body += typeReference.ReferencedType.TypeName;

    return prefix + body;
}
/// <summary>
/// Records in codeElement.ApplicationData whether the statement containing the element
/// increments or decrements it by one (UsageConstants.IsIncrement / UsageConstants.IsDecrement).
/// Both flags are reset to false first. Only fields stored in ruleData.StaticFields as "int"
/// or "long" are examined. Recognized forms, where a is the element itself and redundant
/// parentheses are allowed: a++; ++a; a--; --a; a += 1; a -= 1; a = a + 1; a = a - 1;
/// A plain assignment such as a = 1, a compound form such as a += a + 1, and a++ applied to
/// an expression that merely contains a (e.g. arr[a]++) are not increments/decrements of a.
/// </summary>
/// <param name="codeElement">Reference to a static field; expected to be a CodeNamedReferenceExpression.</param>
/// <param name="trace">Trace of codeElement; the innermost statement is found by walking it from the end.</param>
/// <param name="ruleData">Analysis state providing the static field types.</param>
private void DiscoverIncDecUsage(CodeElement codeElement,CodeElement[] trace,RuleData ruleData)
{
    codeElement.ApplicationData[UsageConstants.IsDecrement] = false;
    codeElement.ApplicationData[UsageConstants.IsIncrement] = false;

    CodeNamedReferenceExpression fieldReference = codeElement as CodeNamedReferenceExpression;
    if(fieldReference == null)
        return;
    // Inc/dec is only valid for ints and longs.
    string fieldType = ruleData.StaticFields[fieldReference.Name] as string;
    if(fieldType != "int" && fieldType != "long")
        return;

    // Innermost statement that contains the element.
    CodeStatement codeStatement = null;
    for(int traceFrameIndex = trace.Length - 1; traceFrameIndex >= 0 && codeStatement == null; traceFrameIndex--)
        codeStatement = trace[traceFrameIndex] as CodeStatement;

    CodeExpressionStatement expressionStatement = codeStatement as CodeExpressionStatement;
    if(expressionStatement == null)
        return;

    CodeBinaryExpression binaryExpression = expressionStatement.Expression as CodeBinaryExpression;
    if(binaryExpression != null)
    {
        // Must be an assignment to the element itself.
        if(UnwrapParenthesizedExpression(binaryExpression.LeftOperand) != codeElement)
            return;
        CodeExpression assignedValue = UnwrapParenthesizedExpression(binaryExpression.RightOperand);

        if(binaryExpression.Operator == CodeBinaryOperatorType.AdditionAssignment ||
            binaryExpression.Operator == CodeBinaryOperatorType.SubtractionAssignment)
        {
            // a += 1; a -= 1;  (the right side must be the literal 1)
            if(IsPrimitiveLiteralOne(assignedValue))
            {
                if(binaryExpression.Operator == CodeBinaryOperatorType.AdditionAssignment)
                    codeElement.ApplicationData[UsageConstants.IsIncrement] = true;
                else
                    codeElement.ApplicationData[UsageConstants.IsDecrement] = true;
            }
            return;
        }

        if(binaryExpression.Operator != CodeBinaryOperatorType.Assign)
            return;

        // a = a + 1; a = a - 1;  (a plain a = 1 is an assignment, not an increment)
        CodeBinaryExpression arithmeticExpression = assignedValue as CodeBinaryExpression;
        if(arithmeticExpression == null)
            return;
        CodeNamedReferenceExpression sameField = UnwrapParenthesizedExpression(arithmeticExpression.LeftOperand) as CodeNamedReferenceExpression;
        if(sameField == null ||
            sameField.Name != fieldReference.Name ||
            !IsPrimitiveLiteralOne(UnwrapParenthesizedExpression(arithmeticExpression.RightOperand)))
        {
            return;
        }
        if(arithmeticExpression.Operator == CodeBinaryOperatorType.Addition)
            codeElement.ApplicationData[UsageConstants.IsIncrement] = true;
        else if(arithmeticExpression.Operator == CodeBinaryOperatorType.Subtraction)
            codeElement.ApplicationData[UsageConstants.IsDecrement] = true;
        return;
    }

    CodeUnaryExpression unaryExpression = expressionStatement.Expression as CodeUnaryExpression;
    if(unaryExpression == null)
        return;
    // The element must be the operand itself: in arr[a]++ the field a is only an index.
    if(UnwrapParenthesizedExpression(unaryExpression.Operand) != codeElement)
        return;
    if(unaryExpression.Operator == CodeUnaryOperatorType.PostfixDecrement || unaryExpression.Operator == CodeUnaryOperatorType.PrefixDecrement)
        codeElement.ApplicationData[UsageConstants.IsDecrement] = true;
    else if(unaryExpression.Operator == CodeUnaryOperatorType.PostfixIncrement || unaryExpression.Operator == CodeUnaryOperatorType.PrefixIncrement)
        codeElement.ApplicationData[UsageConstants.IsIncrement] = true;
}

/// <summary>
/// Strips any number of enclosing parentheses, e.g. ((x)) becomes x.
/// </summary>
/// <param name="expression">Expression to unwrap; may be null.</param>
/// <returns>The innermost non-parenthesized expression.</returns>
private static CodeExpression UnwrapParenthesizedExpression(CodeExpression expression)
{
    while(expression is CodeParenthesizedExpression)
        expression = (expression as CodeParenthesizedExpression).Expression;
    return expression;
}

/// <summary>
/// True when <paramref name="expression"/> is a primitive literal whose value prints as "1".
/// Null literal values are rejected instead of throwing.
/// </summary>
/// <param name="expression">Expression to test; may be null.</param>
/// <returns>true for the literal 1.</returns>
private static bool IsPrimitiveLiteralOne(CodeExpression expression)
{
    CodePrimitiveExpression primitiveExpression = expression as CodePrimitiveExpression;
    return primitiveExpression != null &&
        primitiveExpression.Value != null &&
        primitiveExpression.Value.ToString() == "1";
}
/// <summary>
/// Detects whether the innermost method call enclosing <paramref name="codeElement"/> is
/// Interlocked.Increment or Interlocked.Decrement. The target may be written as
/// System.Threading.Interlocked, Threading.Interlocked or Interlocked, in any letter case.
/// UsageConstants.IsInterlocked is reset to false first. When a matching call is found,
/// IsInterlocked is set to true, together with IsIncrement (for Increment) or IsDecrement
/// (for Decrement); the opposite flag is set to false.
/// </summary>
/// <param name="codeElement">Reference to a static field.</param>
/// <param name="trace">Trace of codeElement; searched from the end for the innermost call.</param>
/// <param name="ruleData">Analysis state (not used here; kept for a uniform signature).</param>
private void DiscoverInterlockedUsage(CodeElement codeElement,CodeElement[] trace,RuleData ruleData)
{
    codeElement.ApplicationData[UsageConstants.IsInterlocked] = false;

    // Only the innermost method call enclosing the element is considered.
    CodeMethodInvokeExpression innermostCall = null;
    for(int traceFrameIndex = trace.Length - 1; traceFrameIndex >= 0 && innermostCall == null; traceFrameIndex--)
        innermostCall = trace[traceFrameIndex] as CodeMethodInvokeExpression;
    if(innermostCall == null)
        return;

    CodeNamedReferenceExpression methodReferenceExpression = innermostCall.MethodReferenceExpression as CodeNamedReferenceExpression;
    if(methodReferenceExpression == null)
        return;

    string referencedTypeName;
    if(methodReferenceExpression.TargetObject is CodeTypeReferenceExpression)
        referencedTypeName = (methodReferenceExpression.TargetObject as CodeTypeReferenceExpression).ReferencedType.TypeName;
    else
        referencedTypeName = CodeNamedReferenceExpressionUtils.GetTargetReferencesAsString(methodReferenceExpression);
    if(referencedTypeName == null || methodReferenceExpression.Name == null)
        return;

    // Lower-case with the invariant culture: under a culture such as tr-TR, ToLower() maps 'I'
    // to a dotless 'ı', so "Interlocked"/"Increment" would never match the keys below.
    System.Globalization.CultureInfo invariantCulture = System.Globalization.CultureInfo.InvariantCulture;
    referencedTypeName = referencedTypeName.ToLower(invariantCulture);
    if(!referencedTypeName.StartsWith("system.threading.interlocked") &&
        !referencedTypeName.StartsWith("threading.interlocked") &&
        !referencedTypeName.StartsWith("interlocked"))
    {
        return;
    }

    string methodName = methodReferenceExpression.Name.ToLower(invariantCulture);
    if(methodName == "increment")
    {
        codeElement.ApplicationData[UsageConstants.IsInterlocked] = true;
        codeElement.ApplicationData[UsageConstants.IsDecrement] = false;
        codeElement.ApplicationData[UsageConstants.IsIncrement] = true;
    }
    else if(methodName == "decrement")
    {
        // Previously recorded as an increment (copy-paste error); Decrement decrements.
        codeElement.ApplicationData[UsageConstants.IsInterlocked] = true;
        codeElement.ApplicationData[UsageConstants.IsDecrement] = true;
        codeElement.ApplicationData[UsageConstants.IsIncrement] = false;
    }
}
/// <summary>
/// Computes how <paramref name="codeElement"/> is used and stores the results in its
/// ApplicationData: static-lock protection (DiscoverStaticLockUsage), increment/decrement
/// form (DiscoverIncDecUsage) and Interlocked usage (DiscoverInterlockedUsage).
/// </summary>
/// <param name="codeElement">Reference to a static field.</param>
/// <param name="ruleData">Analysis state shared by the discovery steps.</param>
private void DiscoverElementUsage(CodeElement codeElement,RuleData ruleData)
{
    CodeElementTrace elementTrace = CodeElementTrace.GetCodeElementTrace(codeElement);
    CodeElement[] frames = elementTrace.Trace;

    DiscoverStaticLockUsage(codeElement, frames, ruleData);
    // Keep this order: DiscoverIncDecUsage resets and fills the IsIncrement/IsDecrement slots,
    // and DiscoverInterlockedUsage may then overwrite them.
    DiscoverIncDecUsage(codeElement, frames, ruleData);
    DiscoverInterlockedUsage(codeElement, frames, ruleData);
}
/// <summary>
/// Classifies a reference to a static <c>int</c>/<c>long</c> field using the usage flags stored in its
/// ApplicationData by DiscoverElementUsage. Raises the matching Interlocked-related violation, if any.
/// </summary>
/// <param name="codeElement">
/// The field reference. It may be null, because the caller passes the result of an <c>as</c> cast.
/// </param>
/// <param name="violations">Receives any violation raised for the reference.</param>
/// <param name="ruleData">Per-analysis data. StaticFields maps static field names to their type names.</param>
/// <returns>
/// true when the reference is fully handled here: a violation was raised, or it is already protected
/// by an Interlocked call or a lock. false when it is not an int/long field, or is not protected at all;
/// the caller then falls back to lock-based protection.
/// </returns>
private bool ProcessIntAndLongRferences(CodeNamedReferenceExpression codeElement,RuleViolationCollection violations,RuleData ruleData)
{
    // A non-named reference cannot be an int/long field; let the caller treat it as lockable.
    if(codeElement == null)
        return false;
    string fieldType = ruleData.StaticFields[codeElement.Name] as string;
    if(fieldType != "int" && fieldType != "long")
        return false;

    // Interlocked operation that the lock-discovery pass flagged as locked alone: the lock is reported as redundant.
    if((bool)codeElement.ApplicationData[UsageConstants.IsInterlocked] &&
        (bool)codeElement.ApplicationData[UsageConstants.IsLockedAlone])
    {
        AddIntAndLongViolation(violations,codeElement,ruleData,"LockedInterlockedOperation",
            ViolationType.LockedInterlockedOperation,RuleViolation.ViolationSeverity.Error);
        return true;
    }

    // Plain ++/-- protected only by a lock flagged as "alone": suggest an Interlocked call instead (warning only).
    if(!(bool)codeElement.ApplicationData[UsageConstants.IsInterlocked] &&
        (bool)codeElement.ApplicationData[UsageConstants.IsLocked] &&
        (bool)codeElement.ApplicationData[UsageConstants.IsLockedAlone] &&
        ((bool)codeElement.ApplicationData[UsageConstants.IsIncrement] ||
         (bool)codeElement.ApplicationData[UsageConstants.IsDecrement]))
    {
        AddIntAndLongViolation(violations,codeElement,ruleData,"LockedIncrementOrDecrement",
            ViolationType.LockedIncrementOrDecrement,RuleViolation.ViolationSeverity.Warning);
        return true;
    }

    // Completely unprotected ++/--: it can be made safe with an Interlocked call.
    if(!(bool)codeElement.ApplicationData[UsageConstants.IsInterlocked] &&
        !(bool)codeElement.ApplicationData[UsageConstants.IsLocked] &&
        ((bool)codeElement.ApplicationData[UsageConstants.IsIncrement] ||
         (bool)codeElement.ApplicationData[UsageConstants.IsDecrement]))
    {
        AddIntAndLongViolation(violations,codeElement,ruleData,"InterlockableUnprotectedReference",
            ViolationType.InterlockableUnprotectedReference,RuleViolation.ViolationSeverity.Error);
        return true;
    }

    // No specific violation: the reference counts as handled only if some protection already exists.
    // Returning false makes the caller start treating it as part of a lock block.
    return (bool)codeElement.ApplicationData[UsageConstants.IsInterlocked] ||
           (bool)codeElement.ApplicationData[UsageConstants.IsLocked];
}
/// <summary>
/// Builds an int/long violation and appends it to <paramref name="violations"/>. The violation's correction
/// text and description come from the "ProtectStaticDataRule|Correction|&lt;key&gt;" and
/// "ProtectStaticDataRule|Violation|&lt;key&gt;" resources.
/// </summary>
/// <param name="violations">Collection that receives the violation.</param>
/// <param name="reference">The offending field reference; also stored as "ReferenceToProtect".</param>
/// <param name="ruleData">Supplies the analyzed type stored as "TypeDeclaration".</param>
/// <param name="resourceKey">Suffix of the localized resource names.</param>
/// <param name="violationType">Value stored as "ViolationType"; Correct switches on it.</param>
/// <param name="severity">Severity of the violation.</param>
private void AddIntAndLongViolation(RuleViolationCollection violations,CodeNamedReferenceExpression reference,RuleData ruleData,string resourceKey,string violationType,RuleViolation.ViolationSeverity severity)
{
    RuleViolation ruleViolation = new RuleViolation(this,reference);
    ruleViolation.AddCorrection(ResourceManager.GetLocalizedString("ProtectStaticDataRule|Correction|" + resourceKey),true,false);
    ruleViolation.Severity = severity;
    ruleViolation.Description = ResourceManager.GetLocalizedString("ProtectStaticDataRule|Violation|" + resourceKey,reference.Name);
    ruleViolation.ViolationData["ViolationType"] = violationType;
    ruleViolation.ViolationData["TypeDeclaration"] = ruleData.AnalyzedType;
    ruleViolation.ViolationData["ReferenceToProtect"] = reference;
    violations.Add(ruleViolation);
}
/// <summary>
/// Collects the candidate lock expressions offered as alternative corrections, in this order:
/// 1. For each distinct field referenced in the block, a reference to a "&lt;field&gt;Lock" field.
///    When the field's type implements ICollection, this is followed by
///    <c>((System.Collections.ICollection)field).SyncRoot</c>.
/// 2. Every existing static field declared as <c>object</c> and initialized with <c>new object()</c>.
/// 3. <c>typeof(&lt;analyzed type&gt;)</c>.
/// Names already offered are not repeated.
/// </summary>
/// <param name="referencesBlock">References that need protection.</param>
/// <param name="typeDeclaration">Type being analyzed; its members are scanned for lock objects.</param>
/// <param name="ruleData">Per-analysis data. StaticFields maps static field names to type names.</param>
/// <returns>The candidate lock expressions.</returns>
private CodeExpressionCollection GetLockExpressions(CodeElementCollection referencesBlock,CodeTypeDeclaration typeDeclaration,RuleData ruleData)
{
    CodeExpressionCollection candidates = new CodeExpressionCollection();
    StringCollection offeredNames = new StringCollection();

    // Step 1: dedicated "<field>Lock" objects, plus SyncRoot for collection fields.
    foreach(CodeElement element in referencesBlock)
    {
        CodeNamedReferenceExpression namedReference = element as CodeNamedReferenceExpression;
        if(namedReference == null)
            continue;
        string fieldName = namedReference.Name;
        string lockName = fieldName + "Lock";
        if(offeredNames.Contains(lockName))
            continue;
        offeredNames.Add(lockName);
        candidates.Add(new CodeFieldReferenceExpression(null,lockName));

        if(!ruleData.StaticFields.Contains(fieldName))
            continue;
        string fieldTypeName = ruleData.StaticFields[fieldName] as string;
        if(!typeDeclaration.GetAssemblyTypeManager().DoesTypeImplementInterface(fieldTypeName,"System.Collections.ICollection"))
            continue;

        CodeTypeCastExpression castToCollection = new CodeTypeCastExpression();
        castToCollection.Expression = new CodeFieldReferenceExpression(null,fieldName);
        castToCollection.TargetType = new CodeTypeReference("System.Collections.ICollection");
        candidates.Add(new CodePropertyReferenceExpression(new CodeParenthesizedExpression(castToCollection),"SyncRoot"));
    }

    // Step 2: existing declarations of the form
    //   [modifiers] static object <name> = new object();
    foreach(CodeTypeMemberDeclaration member in CodeTypeDeclarationUtils.GetTypeMembers(typeDeclaration))
    {
        CodeTypeFieldDeclaration fieldDeclaration = member as CodeTypeFieldDeclaration;
        if(fieldDeclaration == null)
            continue;
        if((member.Modifiers & CodeTypeMemberDeclaration.MemberDeclarationModifiers.Static) == 0)
            continue;
        if(fieldDeclaration.DeclarationType.TypeInfo == null || fieldDeclaration.DeclarationType.TypeInfo.FullName != "System.Object")
            continue;

        foreach(CodeVariableDeclarationMember declaredField in fieldDeclaration.DeclaredFields)
        {
            CodeObjectCreateExpression creation = declaredField.Initializer as CodeObjectCreateExpression;
            if(creation == null)
                continue;
            if(creation.CreateType.TypeInfo == null || creation.CreateType.TypeInfo.FullName != "System.Object")
                continue;
            if(offeredNames.Contains(declaredField.Name))
                continue;
            candidates.Add(new CodeFieldReferenceExpression(null,declaredField.Name));
            offeredNames.Add(declaredField.Name);
        }
    }

    // Step 3: always offer locking on the type object itself.
    candidates.Add(new CodeUnaryExpression(CodeUnaryOperatorType.TypeOf,new CodeTypeReferenceExpression(typeDeclaration.Name)));
    return candidates;
}
/// <summary>
/// Decides whether a single static-field reference must be covered by a lock.
/// The only exemption is a bare read of a reference-typed (non-ValueType) field whose value is not updated.
/// Such reads are treated as atomic because a reference is pointer-sized. Value types are always
/// protected, since whether reading them is atomic depends on their size and the platform
/// (e.g. int vs. double on x86).
/// </summary>
/// <param name="reference">The reference to check; null is conservatively treated as needing protection.</param>
/// <returns>false only for the safe-read case above; true otherwise.</returns>
private bool DoesReferenceNeedsProtection(CodeExpression reference)
{
    // Unknown reference: protect it to be on the safe side.
    if(reference == null)
        return true;

    // The collection iterated by a foreach must be protected for the whole loop.
    CodeForEachStatement owningForEach = reference.Parent as CodeForEachStatement;
    if(owningForEach != null && owningForEach.IteratedCollection == reference)
        return true;

    // Only a bare use of the field (instance, "value = instance", "instance == ...") can be exempt.
    // A member access on it (instance.Length, instance.MyProperty) always needs protection.
    if(reference.Parent is CodeReferenceExpression)
        return true;

    if(reference.ExpressionType == null)
        return true;
    if(reference.GetAssemblyTypeManager().IsTypeSubclassOf(reference.ExpressionType.FullName,"System.ValueType"))
        return true;

    // Bare read of a reference-typed field: it needs protection only if its value is updated.
    return IsElementValueUpdated(reference);
}
/// <summary>
/// Analyzes one block of static-field references and returns the violations to report for it.
/// Every reference first has its usage discovered. Each reference can qualify for the Interlocked path:
/// either ProcessIntAndLongRferences handles it, or it is an un-updated re-read of the same field as the
/// previous reference in the same statement. If all of them qualify, the per-reference violations from
/// ProcessIntAndLongRferences are returned. Otherwise those are discarded and one lock-based violation is raised:
/// LockableUnprotectedReference for a one-reference block, or LockableUnprotectedReferencesBlock spanning the
/// first..last reference. That violation carries one correction per candidate from GetLockExpressions.
/// </summary>
/// <param name="referencesBlock">References to analyze; elements are cast to CodeNamedReferenceExpression below.</param>
/// <param name="ruleData">Per-analysis data (static field types, analyzed type).</param>
/// <returns>The violations for the block; empty when nothing needs protection or an existing lock covers it.</returns>
private RuleViolationCollection AnalyzeMultipleReferences(CodeElementCollection referencesBlock,RuleData ruleData)
{
RuleViolationCollection violations = new RuleViolationCollection();
//Discover element usage (fills the usage flags in each element's ApplicationData).
foreach(CodeElement codeElement in referencesBlock)
DiscoverElementUsage(codeElement,ruleData);
bool isAllReferencesInterlockable = true;
//Analyze element usage and raise violations where appropriate.
//Interlockable (int/long) references are tried first. If the block contains only references that
//can be handled with Interlocked calls, interlock them all; otherwise ignore Interlocked and
//treat the references as one block protected by a common lock.
for(int elementIndex = 0; elementIndex < referencesBlock.Count;elementIndex++)
{
//Tricky case: "i = i + 1;" yields two references. The second one can be ignored when it is only
//read (not updated), belongs to the same statement as the previous reference, and names the same field.
if(!ProcessIntAndLongRferences(referencesBlock[elementIndex] as CodeNamedReferenceExpression,violations,ruleData))
{
//The reference could not be handled as an int/long increment/decrement. Unless it is the harmless
//re-read described above, the block cannot be protected with Interlocked calls alone.
if(elementIndex != 0 && !IsElementValueUpdated(referencesBlock[elementIndex]) && CodeStatementUtils.GetElementStatement(referencesBlock[elementIndex]) == CodeStatementUtils.GetElementStatement(referencesBlock[elementIndex-1]) &&
String.Compare((referencesBlock[elementIndex-1] as CodeNamedReferenceExpression).Name ,(referencesBlock[elementIndex] as CodeNamedReferenceExpression).Name,!referencesBlock[elementIndex].CompileUnit.IsCaseSensitive) == 0)
{
continue;
}
else
{
isAllReferencesInterlockable = false;
break;
}
}
}
//Does the whole block consist of interlockable references? Then the per-reference violations are the result.
if(isAllReferencesInterlockable)
return violations;
//Discard any violations raised while processing int/long references; the block-level
//violation below covers them.
violations.Clear();
//Protect either the single reference or the whole block with a common lock.
if(referencesBlock.Count == 1)
{
//If this single reference does not need protection, there is nothing to report.
if(!DoesReferenceNeedsProtection(referencesBlock[0] as CodeExpression))
return violations;
DiscoverStaticLockUsage(referencesBlock[0],CodeElementTrace.GetCodeElementTrace(referencesBlock[0]).Trace,ruleData);
//Already under a lock: nothing to report.
if((bool)referencesBlock[0].ApplicationData[UsageConstants.IsLocked])
return violations;
RuleViolation ruleViolation = new RuleViolation(this,referencesBlock[0]);
//One correction per candidate lock expression; the chosen expression travels in CorrectionData.
foreach(CodeExpression lockExpression in GetLockExpressions(referencesBlock,ruleData.AnalyzedType,ruleData))
{
string lockCode = CSharpCodeGenerator.GenerateCode(lockExpression,0).ToString();
RuleViolationCorrection correction = ruleViolation.AddCorrection(ResourceManager.GetLocalizedString("ProtectStaticDataRule|Correction|LockableUnprotectedReference",lockCode),false,false);
correction.CorrectionData["LockExpression"] = lockExpression;
}
ruleViolation.Severity = RuleViolation.ViolationSeverity.Error;
ruleViolation.Description = ResourceManager.GetLocalizedString("ProtectStaticDataRule|Violation|LockableUnprotectedReference",(referencesBlock[0] as CodeNamedReferenceExpression).Name);
ruleViolation.ViolationData["ViolationType"] = ViolationType.LockableUnprotectedReference;
ruleViolation.ViolationData["TypeDeclaration"] = ruleData.AnalyzedType;
ruleViolation.ViolationData["ReferenceToProtect"] = referencesBlock[0];
violations.Add(ruleViolation);
}
else
{
//What remains is a block of references that may need a common lock.
//Not every reference in the block needs protection (e.g. plain reads used in comparisons, see
//DoesReferenceNeedsProtection), so first collect the references that do.
CodeElementCollection refencesNeededProtection = new CodeElementCollection();
foreach(CodeElement reference in referencesBlock)
if(DoesReferenceNeedsProtection(reference as CodeExpression))
refencesNeededProtection.Add(reference);
if(refencesNeededProtection.Count == 0)
return violations;
//Is the common container of those references already under a lock?
CodeElement commonContainer = CodeElementUtils.GetElementsCommonContainer(refencesNeededProtection);
DiscoverStaticLockUsage(commonContainer,CodeElementTrace.GetCodeElementTrace(commonContainer).Trace,ruleData);
if((bool)commonContainer.ApplicationData[UsageConstants.IsLocked])
return violations;
//No common lock covers the block: raise one violation spanning first..last reference.
CodeNamedReferenceExpression firstReference = referencesBlock[0] as CodeNamedReferenceExpression;
CodeNamedReferenceExpression lastReference = referencesBlock[referencesBlock.Count-1] as CodeNamedReferenceExpression;
RuleViolation ruleViolation = new RuleViolation(this,firstReference);
foreach(CodeExpression lockExpression in GetLockExpressions(referencesBlock,ruleData.AnalyzedType,ruleData))
{
string lockCode = CSharpCodeGenerator.GenerateCode(lockExpression,0).ToString();
RuleViolationCorrection correction = ruleViolation.AddCorrection(ResourceManager.GetLocalizedString("ProtectStaticDataRule|Correction|LockableUnprotectedReferencesBlock",lockCode),false,false);
correction.CorrectionData["LockExpression"] = lockExpression;
}
ruleViolation.Severity = RuleViolation.ViolationSeverity.Error;
ruleViolation.Description = ResourceManager.GetLocalizedString("ProtectStaticDataRule|Violation|LockableUnprotectedReferencesBlock",
firstReference.Name,
firstReference.SourcePosition.Line.ToString(),lastReference.Name,lastReference.SourcePosition.Line.ToString());
ruleViolation.ViolationData["ViolationType"] = ViolationType.LockableUnprotectedReferencesBlock;
ruleViolation.ViolationData["TypeDeclaration"] = ruleData.AnalyzedType;
//The container to lock is computed over ALL references, not only the ones that need protection.
ruleViolation.ViolationData["ContainerToProtect"] = CodeElementUtils.GetElementsCommonContainer(referencesBlock);
ruleViolation.ViolationData["FirstReference"] = firstReference;
ruleViolation.ViolationData["LastReference"] = lastReference;
violations.Add(ruleViolation);
}
return violations;
}
/// <summary>
/// Fixes an InterlockableUnprotectedReference violation. The nearest statement enclosing
/// <paramref name="referenceToProtect"/> is replaced by <c>System.Threading.Interlocked.Increment(ref field);</c>,
/// or <c>Decrement</c> when the reference's IsIncrement usage flag is false.
/// </summary>
/// <param name="referenceToProtect">The static field reference being incremented or decremented.</param>
/// <param name="codeEffector">Effector that records the replacement.</param>
private void ProtectInterlockableReference(CodeNamedReferenceExpression referenceToProtect,CodeEffector codeEffector)
{
    CodeExpressionStatement replacementStatement = new CodeExpressionStatement();
    CodeMethodInvokeExpression interlockedCall = new CodeMethodInvokeExpression();
    replacementStatement.Expression = interlockedCall;

    CodeNamedReferenceExpression interlockedMethod = new CodeMethodReferenceExpression();
    interlockedMethod.TargetObject = new CodeTypeReferenceExpression("System.Threading.Interlocked");
    bool isIncrement = (bool)referenceToProtect.ApplicationData[UsageConstants.IsIncrement];
    interlockedMethod.Name = isIncrement ? "Increment" : "Decrement";

    // The field is passed by reference: Interlocked.X(ref field).
    CodeArgument fieldArgument = new CodeArgument();
    interlockedCall.Arguments.Add(fieldArgument);
    interlockedCall.MethodReferenceExpression = interlockedMethod;
    fieldArgument.Modifier = CodeArgument.CodeArgumentModifier.Ref;
    fieldArgument.Value = referenceToProtect.Clone() as CodeExpression;

    // Walk outward from the reference to the nearest enclosing statement and replace it.
    CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(referenceToProtect).Trace;
    int frameIndex = trace.Length - 1;
    while(frameIndex >= 0 && !(trace[frameIndex] is CodeStatement))
        frameIndex--;
    if(frameIndex >= 0)
        codeEffector.ReplaceCodeElement(trace[frameIndex],replacementStatement);
}
/// <summary>
/// Fixes a LockedIncrementOrDecrement violation. The nearest lock statement enclosing
/// <paramref name="referenceToProtect"/> is replaced by <c>System.Threading.Interlocked.Increment(ref field);</c>,
/// or <c>Decrement</c> when the reference's IsIncrement usage flag is false.
/// </summary>
/// <param name="referenceToProtect">The locked static field reference being incremented or decremented.</param>
/// <param name="codeEffector">Effector that records the replacement.</param>
private void ReplaceLockWithInterlockedMethod(CodeNamedReferenceExpression referenceToProtect,CodeEffector codeEffector)
{
    CodeExpressionStatement replacementStatement = new CodeExpressionStatement();
    CodeMethodInvokeExpression interlockedCall = new CodeMethodInvokeExpression();
    replacementStatement.Expression = interlockedCall;

    CodeNamedReferenceExpression interlockedMethod = new CodeMethodReferenceExpression();
    interlockedCall.MethodReferenceExpression = interlockedMethod;
    interlockedMethod.TargetObject = new CodeTypeReferenceExpression("System.Threading.Interlocked");
    bool isIncrement = (bool)referenceToProtect.ApplicationData[UsageConstants.IsIncrement];
    interlockedMethod.Name = isIncrement ? "Increment" : "Decrement";

    // The field is passed by reference: Interlocked.X(ref field).
    CodeArgument fieldArgument = new CodeArgument();
    interlockedCall.Arguments.Add(fieldArgument);
    fieldArgument.Modifier = CodeArgument.CodeArgumentModifier.Ref;
    fieldArgument.Value = referenceToProtect.Clone() as CodeExpression;

    // Walk outward from the reference to the nearest enclosing lock statement; the whole lock is replaced.
    CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(referenceToProtect).Trace;
    int frameIndex = trace.Length - 1;
    while(frameIndex >= 0 && !(trace[frameIndex] is CodeLockStatement))
        frameIndex--;
    if(frameIndex >= 0)
        codeEffector.ReplaceCodeElement(trace[frameIndex],replacementStatement);
}
/// <summary>
/// Fixes a LockableUnprotectedReference violation by wrapping the nearest statement that contains
/// <paramref name="referenceToProtect"/> in <c>lock(&lt;chosen expression&gt;) { ... }</c>.
/// A statement owned directly by a using statement is not wrapped on its own; the using statement is locked instead.
/// If the locked statement is a variable declaration, the declaration is hoisted in front of the lock.
/// The hoisted declaration goes into the block that directly contains the lock, so the variable keeps its original scope.
/// </summary>
/// <param name="referenceToProtect">The unprotected static field reference.</param>
/// <param name="codeEffector">Effector that records the edits.</param>
/// <param name="correction">Chosen correction; CorrectionData["LockExpression"] is the expression to lock on.</param>
private void ProtectLockableReference(CodeElement referenceToProtect,CodeEffector codeEffector,RuleViolationCorrection correction)
{
    CodeExpression lockExpression = correction.CorrectionData["LockExpression"] as CodeExpression;
    CodeStatementBlock protectedBlock = new CodeStatementBlock();
    CodeLockStatement lockStatement = new CodeLockStatement(lockExpression,protectedBlock);
    AddLocableFieldMemberIfNecessary(correction,codeEffector);

    CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(referenceToProtect).Trace;
    bool statementProtected = false;
    // Block that directly contains the locked statement, and that statement's index inside it.
    // Only the innermost such block is used. Hoisting declarations into an outer block (e.g. the method body)
    // would widen their scope and can clash with same-named locals elsewhere.
    CodeStatementBlock enclosingBlock = null;
    int statementInTheBlockIndex = -1;
    for(int i = trace.Length-1; i >= 0; i--)
    {
        // Never look beyond the declaring type.
        if(trace[i] is CodeTypeDeclaration)
            break;
        if(!statementProtected)
        {
            if(!(trace[i] is CodeStatement))
                continue;
            if(i != 0 && trace[i-1] is CodeUsingStatement)
                continue;
            protectedBlock.Statements.Add(trace[i] as CodeStatement);
            codeEffector.ReplaceCodeElement(trace[i],lockStatement);
            statementProtected = true;
            continue;
        }
        if(trace[i] is CodeStatementBlock)
        {
            enclosingBlock = trace[i] as CodeStatementBlock;
            statementInTheBlockIndex = enclosingBlock.Statements.IndexOf(trace[i+1] as CodeStatement);
            break;
        }
    }
    MoveVariableDeclarationsFromLockedBlock(enclosingBlock,statementInTheBlockIndex,protectedBlock.Statements,codeEffector);
}
/// <summary>
/// Ensures the lock field exists when the chosen lock expression is a plain field reference
/// (e.g. a suggested "&lt;field&gt;Lock"). If the analyzed type does not already declare that field,
/// <c>private static object &lt;name&gt; = new object();</c> is appended to the type's members.
/// Other lock expressions (SyncRoot, typeof) need no backing field.
/// </summary>
/// <param name="correction">
/// Chosen correction. CorrectionData["LockExpression"] holds the lock expression, and its violation's
/// ViolationData["TypeDeclaration"] holds the analyzed type.
/// </param>
/// <param name="codeEffector">Effector that records the added member.</param>
private void AddLocableFieldMemberIfNecessary(RuleViolationCorrection correction,CodeEffector codeEffector)
{
    CodeFieldReferenceExpression lockField = correction.CorrectionData["LockExpression"] as CodeFieldReferenceExpression;
    if(lockField == null)
        return;
    CodeTypeDeclaration owningType = correction.Violation.ViolationData["TypeDeclaration"] as CodeTypeDeclaration;
    if(IsLockFieldAlreadyDeclared(owningType,lockField.Name))
        return;

    CodeTypeFieldDeclaration newField = new CodeTypeFieldDeclaration();
    newField.DeclarationType = new CodeTypeReference("object");
    newField.DeclaredFields.Add(new CodeVariableDeclarationMember(lockField.Name,new CodeObjectCreateExpression(new CodeTypeReference("object"))));
    newField.Modifiers = CodeTypeMemberDeclaration.MemberDeclarationModifiers.Private | CodeTypeMemberDeclaration.MemberDeclarationModifiers.Static;

    CodeTypeMemberDeclarationCollection newMembers = new CodeTypeMemberDeclarationCollection();
    newMembers.Add(newField);
    // Insert after the last existing member.
    codeEffector.AddCodeElements(newMembers,owningType,CodeTypeDeclarationUtils.GetTypeMembers(owningType),CodeTypeDeclarationUtils.GetTypeMembers(owningType).Count);
}
/// <summary>
/// Returns true if <paramref name="typeDeclaration"/> declares a field named <paramref name="fieldName"/>.
/// The name comparison ignores case when the field's compile unit is case-insensitive.
/// </summary>
private static bool IsLockFieldAlreadyDeclared(CodeTypeDeclaration typeDeclaration,string fieldName)
{
    foreach(CodeTypeMemberDeclaration member in CodeTypeDeclarationUtils.GetTypeMembers(typeDeclaration))
    {
        CodeTypeFieldDeclaration fieldDeclaration = member as CodeTypeFieldDeclaration;
        if(fieldDeclaration == null)
            continue;
        foreach(CodeVariableDeclarationMember declaredField in fieldDeclaration.DeclaredFields)
        {
            bool ignoreCase = !declaredField.CompileUnit.IsCaseSensitive;
            if(string.Compare(declaredField.Name,fieldName,ignoreCase) == 0)
                return true;
        }
    }
    return false;
}
/// <summary>
/// Keeps variables declared inside a newly created lock body visible after the lock.
/// Each variable declaration among <paramref name="protectedStatements"/> is re-emitted in front of the lock,
/// in <paramref name="declarationsContainer"/> at <paramref name="firstDeclarationIndex"/>.
/// Its initializers become assignments at the declaration's old position, in their original order.
/// Const declarations are hoisted whole, initializers included. C# requires a const local to be initialized
/// and forbids assigning to it, and a const initializer is a compile-time constant, so evaluating it earlier changes nothing.
/// Declarations owned directly by a using statement are left in place.
/// </summary>
/// <param name="declarationsContainer">Block that directly contains the lock statement.</param>
/// <param name="firstDeclarationIndex">Position of the lock statement inside <paramref name="declarationsContainer"/>.</param>
/// <param name="protectedStatements">Statements inside the lock body; modified in place.</param>
/// <param name="codeEffector">Effector that records the inserted declarations.</param>
private void MoveVariableDeclarationsFromLockedBlock(CodeStatementBlock declarationsContainer,int firstDeclarationIndex,CodeStatementCollection protectedStatements,CodeEffector codeEffector)
{
    CodeStatementCollection statementsToInsert = new CodeStatementCollection();
    int protectedStatmentIndex = 0;
    // The loop advances the index manually because the collection is edited while it is scanned.
    while(protectedStatmentIndex < protectedStatements.Count)
    {
        CodeVariableDeclarationStatement variableDeclaration = protectedStatements[protectedStatmentIndex] as CodeVariableDeclarationStatement;
        if(variableDeclaration == null)
        {
            protectedStatmentIndex++;
            continue;
        }
        // A declaration owned by a using statement is scoped by that using statement, which is itself
        // inside the lock, so its visibility cannot exceed the lock's and nothing needs to move.
        CodeElement[] trace = CodeElementUtils.GetTrace(variableDeclaration);
        if(trace.Length >= 2 && trace[trace.Length - 2] is CodeUsingStatement)
        {
            protectedStatmentIndex++;
            continue;
        }

        bool isConst = variableDeclaration.IsConst;

        // Declaration to hoist in front of the lock.
        CodeVariableDeclarationStatement newDeclarationStatement = new CodeVariableDeclarationStatement();
        newDeclarationStatement.DeclarationType = variableDeclaration.DeclarationType;
        newDeclarationStatement.IsConst = isConst;
        newDeclarationStatement.DocumentationComment = variableDeclaration.DocumentationComment;
        newDeclarationStatement.Comment = variableDeclaration.Comment;
        foreach(CodeVariableDeclarationMember member in variableDeclaration.DeclaredVariables)
        {
            CodeVariableDeclarationMember newMember = member.Clone() as CodeVariableDeclarationMember;
            if(!isConst)
                newMember.Initializer = null;
            newDeclarationStatement.DeclaredVariables.Add(newMember);
        }
        statementsToInsert.Add(newDeclarationStatement);

        // Turn the initializers into assignments at the declaration's old position.
        int insertIndex = protectedStatmentIndex;
        if(!isConst)
        {
            foreach(CodeVariableDeclarationMember member in variableDeclaration.DeclaredVariables)
            {
                // Uninitialized variables need no assignment.
                if(member.Initializer == null)
                    continue;
                CodeExpressionStatement assignmentStatement = new CodeExpressionStatement();
                CodeBinaryExpression assignmentExpression = new CodeBinaryExpression();
                assignmentStatement.Expression = assignmentExpression;
                assignmentExpression.Operator = CodeBinaryOperatorType.Assign;
                assignmentExpression.LeftOperand = new CodeVariableReferenceExpression(member.Name,newDeclarationStatement);
                assignmentExpression.RightOperand = member.Initializer;
                // Each assignment goes after the previous one, preserving source order
                // ("int a = 1, b = a + 1;" must assign a before b).
                protectedStatements.Insert(insertIndex,assignmentStatement);
                insertIndex++;
            }
        }
        protectedStatements.Remove(variableDeclaration);
        // The generated assignments now occupy [old index, insertIndex), and the statement that followed
        // the declaration sits at insertIndex. Resume there so that statement is not skipped.
        protectedStatmentIndex = insertIndex;
    }
    if(statementsToInsert.Count != 0)
        codeEffector.AddCodeElements(statementsToInsert,declarationsContainer,declarationsContainer.Statements,firstDeclarationIndex);
}
/// <summary>
/// Fixes a LockableUnprotectedReferencesBlock violation by wrapping statements in a single lock.
/// When the nearest statement enclosing <paramref name="blockContainer"/> is a statement block, the statements from
/// the one holding <paramref name="firstReference"/> through the one holding <paramref name="lastReference"/> are
/// moved into the lock, and declarations among them are hoisted in front of it.
/// Otherwise that enclosing statement itself becomes the lock body.
/// </summary>
/// <param name="blockContainer">Common container of all references in the block.</param>
/// <param name="firstReference">First reference in the block.</param>
/// <param name="lastReference">Last reference in the block.</param>
/// <param name="codeEffector">Effector that records the edits.</param>
/// <param name="correction">Chosen correction; CorrectionData["LockExpression"] is the expression to lock on.</param>
/// <returns>The newly created lock statement.</returns>
private CodeLockStatement ProtectLockableBlock(CodeElement blockContainer,CodeElement firstReference,CodeElement lastReference,CodeEffector codeEffector,RuleViolationCorrection correction)
{
    CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(blockContainer).Trace;
    // Nearest statement enclosing the container. A statement owned directly by a using statement is skipped,
    // so the using statement itself is chosen. The scan stops at the declaring type.
    CodeStatement containerStatement = null;
    for(int i = trace.Length-1; i >= 0; i--)
    {
        if(trace[i] is CodeTypeDeclaration)
            break;
        if(!(trace[i] is CodeStatement))
            continue;
        if(i != 0 && trace[i-1] is CodeUsingStatement)
            continue;
        containerStatement = trace[i] as CodeStatement;
        break;
    }
    CodeExpression lockExpression = correction.CorrectionData["LockExpression"] as CodeExpression;
    CodeLockStatement lockStatement = new CodeLockStatement(lockExpression);
    AddLocableFieldMemberIfNecessary(correction,codeEffector);
    if(containerStatement is CodeStatementBlock)
    {
        CodeStatementBlock lockedStatements = new CodeStatementBlock();
        lockStatement.Statement = lockedStatements;
        // Position of the lock statement inside containerStatement.
        int lockStatementIndex = -1;
        // True while copying the statements between the first and the last reference.
        bool addingStatements = false;
        foreach(CodeStatement statement in (containerStatement as CodeStatementBlock).Statements)
        {
            if(addingStatements)
            {
                // Later statements are moved into the lock: cloned there and deleted from their old place.
                CodeStatement clonedStatement = statement.Clone() as CodeStatement;
                lockedStatements.Statements.Add(clonedStatement);
                codeEffector.DeleteCodeElement(statement);
            }
            if(addingStatements && CodeElementUtils.IsDescendantOf(statement,lastReference))
                addingStatements = false;
            if(!addingStatements && CodeElementUtils.IsDescendantOf(statement,firstReference))
            {
                // The first statement is replaced by the lock itself.
                lockedStatements.Statements.Add(statement.Clone() as CodeStatement);
                lockStatementIndex = (containerStatement as CodeStatementBlock).Statements.IndexOf(statement);
                codeEffector.ReplaceCodeElement(statement,lockStatement);
                // If this statement also holds the last reference, the locked range is already complete.
                // Without this check every following statement would be pulled into the lock.
                addingStatements = !CodeElementUtils.IsDescendantOf(statement,lastReference);
            }
        }
        Debug.Assert(lockStatementIndex != -1,"Why protected statements do not exist? There must be at least one protected statement");
        MoveVariableDeclarationsFromLockedBlock((containerStatement as CodeStatementBlock),lockStatementIndex,lockedStatements.Statements,codeEffector);
    }
    else
    {
        lockStatement.Statement = containerStatement as CodeEmbeddedStatement;
        codeEffector.ReplaceCodeElement(containerStatement,lockStatement);
    }
    return lockStatement;
}
/// <summary>
/// Applies the requested corrections as one change set. Violations already marked as fixed are skipped,
/// and each handled violation is marked fixed. After the first batch of changes is performed, the lock
/// statements created for block corrections are walked again. Inside them, Interlocked calls made redundant
/// by the new lock are rewritten as ++/--.
/// </summary>
/// <param name="requestedCorrections">Corrections chosen by the user.</param>
/// <param name="codeEffector">Effector that records and commits the edits.</param>
/// <param name="cancelCorrectionEvent">Not consulted by this implementation.</param>
public override void Correct(RuleViolationCorrection[] requestedCorrections,CodeEffector codeEffector, System.Threading.ManualResetEvent cancelCorrectionEvent)
{
    // Lock statements created by ProtectLockableBlock; re-walked after the first pass.
    CodeElementCollection containersToOptimizeLocking = new CodeElementCollection();
    codeEffector.BeginCodeChanges(ResourceManager.GetLocalizedString("ProtectStaticDataRule|Correction|Description"));
    foreach(RuleViolationCorrection requestedCorrection in requestedCorrections)
    {
        if(requestedCorrection.Violation.IsFixed)
            continue;
        switch((string)requestedCorrection.Violation.ViolationData["ViolationType"])
        {
            case ViolationType.InterlockableUnprotectedReference:
            {
                ProtectInterlockableReference(requestedCorrection.Violation.ViolationData["ReferenceToProtect"] as CodeNamedReferenceExpression,codeEffector);
                requestedCorrection.Violation.IsFixed = true;
                break;
            }
            case ViolationType.LockableUnprotectedReference:
            {
                ProtectLockableReference(requestedCorrection.Violation.ViolationData["ReferenceToProtect"] as CodeNamedReferenceExpression,codeEffector,requestedCorrection);
                requestedCorrection.Violation.IsFixed = true;
                break;
            }
            case ViolationType.LockableUnprotectedReferencesBlock:
            {
                CodeElement containerToLock = requestedCorrection.Violation.ViolationData["ContainerToProtect"] as CodeElement;
                CodeElement firstReference = requestedCorrection.Violation.ViolationData["FirstReference"] as CodeElement;
                CodeElement lastReference = requestedCorrection.Violation.ViolationData["LastReference"] as CodeElement;
                containersToOptimizeLocking.Add(ProtectLockableBlock(containerToLock,firstReference,lastReference,codeEffector,requestedCorrection));
                requestedCorrection.Violation.IsFixed = true;
                break;
            }
            case ViolationType.LockedIncrementOrDecrement:
            {
                ReplaceLockWithInterlockedMethod(requestedCorrection.Violation.ViolationData["ReferenceToProtect"] as CodeNamedReferenceExpression,codeEffector);
                requestedCorrection.Violation.IsFixed = true;
                break;
            }
            case ViolationType.LockedInterlockedOperation:
            {
                // Remove the lock that wraps the Interlocked call, keeping its body.
                CodeElement codeElement = requestedCorrection.Violation.ViolationData["ReferenceToProtect"] as CodeElement;
                CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(codeElement).Trace;
                for(int i = trace.Length-1; i >=0; i--)
                {
                    if(trace[i] is CodeLockStatement)
                    {
                        CodeLockStatement lockStatement = trace[i] as CodeLockStatement;
                        // Unwrap single-statement blocks so the body replaces the lock directly.
                        CodeStatement lockNestedStatement = lockStatement.Statement;
                        while(lockNestedStatement is CodeStatementBlock && (lockNestedStatement as CodeStatementBlock).Statements.Count == 1)
                            lockNestedStatement = (lockNestedStatement as CodeStatementBlock).Statements[0];
                        codeEffector.ReplaceCodeElement(lockStatement,lockNestedStatement);
                        // Only the innermost lock is redundant. Outer locks may guard other statements,
                        // and replacing them too would issue conflicting edits on overlapping elements.
                        break;
                    }
                }
                requestedCorrection.Violation.IsFixed = true;
                break;
            }
            default:
            {
                Debug.Assert(false,"What is this unknow violation type?");
                break;
            }
        }
    }
    // Apply the first batch of changes.
    codeEffector.PerformCodeChanges();
    foreach(CodeElement container in containersToOptimizeLocking)
        CodeDomWalker.WalkCodeElement(container,new CodeDomWalker.WalkerCallback(ReplaceInterlocksWithUnaryOps),codeEffector);
    codeEffector.CommitCodeChanges();
}
/// <summary>
/// Walker callback used by Correct on newly created lock statements. Once code runs under a lock, an
/// Interlocked.Increment/Decrement call whose result is discarded can be a plain postfix ++/--.
/// For every element flagged IsInterlocked, the expression statement holding that call is replaced
/// by <c>field++;</c> or <c>field--;</c>.
/// </summary>
/// <param name="codeElement">Element currently visited.</param>
/// <param name="notificationType">Only OnElement notifications are handled.</param>
/// <param name="walkerContext">Unused.</param>
/// <param name="applicationData">The CodeEffector used to record replacements.</param>
/// <returns>Always Next, so the walk continues.</returns>
private CodeDomWalker.WalkerCallbackReturn ReplaceInterlocksWithUnaryOps(CodeElement codeElement, CodeDomWalker.CallBackNotificationType notificationType, CodeDomWalkerContext walkerContext,object applicationData)
{
    if(notificationType == CodeDomWalker.CallBackNotificationType.OnElement)
    {
        if(codeElement.ApplicationData.Contains(UsageConstants.IsInterlocked) && (bool)codeElement.ApplicationData[UsageConstants.IsInterlocked])
        {
            CodeElement[] trace = CodeElementTrace.GetCodeElementTrace(codeElement).Trace;
            for(int traceFrameIndex = trace.Length-1; traceFrameIndex >= 0; traceFrameIndex--)
            {
                CodeMethodInvokeExpression interlockedCall = trace[traceFrameIndex] as CodeMethodInvokeExpression;
                if(interlockedCall == null)
                    continue;
                // The nearest invocation is the Interlocked call the flagged field is passed to. Rewrite it only
                // when its result is discarded, i.e. it forms a whole expression statement. Never look further out:
                // otherwise "Foo(Interlocked.Increment(ref x));" would have its entire statement replaced by "x++;".
                if(traceFrameIndex > 0 && trace[traceFrameIndex-1] is CodeExpressionStatement)
                {
                    CodeEffector codeEffector = applicationData as CodeEffector;
                    CodeExpressionStatement expressionStatement = new CodeExpressionStatement();
                    CodeUnaryExpression unaryExpression = new CodeUnaryExpression();
                    expressionStatement.Expression = unaryExpression;
                    // Take the direction from the call being replaced rather than from the IsIncrement flag.
                    // The discovery code's Decrement branch records IsIncrement = true, which would turn
                    // Interlocked.Decrement into "++".
                    bool isIncrement;
                    CodeNamedReferenceExpression invokedMethod = interlockedCall.MethodReferenceExpression as CodeNamedReferenceExpression;
                    if(invokedMethod != null && invokedMethod.Name != null)
                        isIncrement = String.Compare(invokedMethod.Name,"Decrement",true) != 0;
                    else
                        isIncrement = (bool)codeElement.ApplicationData[UsageConstants.IsIncrement];
                    unaryExpression.Operator = isIncrement ? CodeUnaryOperatorType.PostfixIncrement : CodeUnaryOperatorType.PostfixDecrement;
                    unaryExpression.Operand = codeElement as CodeExpression;
                    codeEffector.ReplaceCodeElement(trace[traceFrameIndex-1],expressionStatement);
                }
                break;
            }
        }
    }
    return CodeDomWalker.WalkerCallbackReturn.Next;
}
/// <summary>Code element types handed to this rule for analysis.</summary>
public override Type[] TargetedCodeElements
{
    get { return this.targetedCodeElements; }
}
/// <summary>Applicability scope types under which this rule can be applied.</summary>
public override Type[] ApplicableApplicabilityScopeTypes
{
    get { return ProtectStaticDataRule.applicableApplicabilityScopeTypes; }
}
/// <summary>Shared zero-length violation array.</summary>
private static RuleViolation[] noViolations = new RuleViolation[] {};
/// <summary>
/// Code element types this rule analyzes. Only class declarations are listed.
/// </summary>
private Type[] targetedCodeElements = new Type[] { typeof(CodeClassDeclaration) };
/// <summary>Applicability scopes (method, property, event, file) in which the rule can be applied.</summary>
private static Type[] applicableApplicabilityScopeTypes = new Type[]
{
    typeof(MethodApplicabilityScope),
    typeof(PropertyApplicabilityScope),
    typeof(EventApplicabilityScope),
    typeof(FileApplicabilityScope)
};
}
}
|