Using Expression Trees to modify String and Child Class/Interface String Properties

Time:08-10

I am versed in the basics of Expression Trees, but I am having difficulty with this implementation:

I need to modify an instance of a class/interface by finding every string property and encrypting it. I also need to recursively do the same to every "child" class/interface property, all the way down the rabbit hole. This includes IEnumerable<T> where T : class as well. I have the encryption sorted, but creating the expression tree to do this for any T passed in is frankly beyond my understanding at this point.

I tried implementing this with reflection but performance quickly became an issue - here's what I've got so far:

Using the AggregateHelper and PropertyHelper classes from this post I was able to perform the basic string encryption:

Simple Test Classes:

public class SimpleEncryptionTest
{
   public string String1 { get; set; }
}

public class NestedClassEncryptionTest
{
   public string SomeString { get; set; }
   public SimpleEncryptionTest Inner { get; set; } 
}

public class ListClassEncryptionTest
{
    public List<string> StringList { get; set; }
}

The unit tests I'm using to verify results:

[TestMethod]
public void Encryption_works_on_a_simple_class()
{
    var testString = "this is only a test. If this had been a real emergency...";
    var sut = new SimpleEncryptionTest() { String1 = testString };
    sut.EncryptStringProperties();

    Assert.AreNotEqual(testString, sut.String1);
}

[TestMethod]
public void Round_trip_works_on_a_simple_class()
{
    var testString = "what string should I use?";
    var sut = new SimpleEncryptionTest() { String1 = testString };

    sut.EncryptStringProperties();
    sut.DecryptStringProperties();

    Assert.AreEqual(testString, sut.String1);
}

[TestMethod]
public void Round_trip_works_in_nested_class_scenario()
{
    var outerString = "what is your name?";
    var innerString = "Tony; what's your name?";

    var sut = new NestedClassEncryptionTest{
        SomeString = outerString,
        Inner = new SimpleEncryptionTest(){String1 = innerString }
    };

    sut.EncryptStringProperties();

    Assert.AreNotEqual(outerString, sut.SomeString);
    Assert.AreNotEqual(innerString, sut.Inner.String1);

    sut.DecryptStringProperties();

    Assert.AreEqual(outerString, sut.SomeString);
    Assert.AreEqual(innerString, sut.Inner.String1);
}

[TestMethod]
public void Round_trip_works_on_lists()
{
    var testone = "one";
    var testtwo = "two";
    var testStrings = new List<string>() { testone, testtwo };
    var sut = new ListClassEncryptionTest() { StringList = testStrings };

    sut.EncryptStringProperties();

    Assert.AreNotEqual(testone, sut.StringList[0]);
    Assert.AreNotEqual(testtwo, sut.StringList[1]);

    sut.DecryptStringProperties();

    Assert.AreEqual(testone, sut.StringList[0]);
    Assert.AreEqual(testtwo, sut.StringList[1]);
}

And the extension methods I'm using to Encrypt/Decrypt the values:

/// <summary>
/// Iterates through an Object's properties and encrypts any strings it finds, 
/// recursively iterating any class or interface Properties.
/// </summary>
/// <typeparam name="T">Any reference type</typeparam>
/// <param name="obj">The object to encrypt</param>
/// <returns>The original object, with its string values encrypted.</returns>
public static T EncryptStringProperties<T>(this T obj)
    where T : class
{
    return Crypt(obj, EncryptStringAction);
}

/// <summary>
/// Iterates through an Object's properties and decrypts any strings it finds, 
/// recursively iterating any class or interface Properties.
/// </summary>
/// <typeparam name="T">Any reference type</typeparam>
/// <param name="obj">The object to decrypt</param>
/// <returns>The original object, with its string values decrypted.</returns>
public static T DecryptStringProperties<T>(this T obj)
    where T : class
{
    return Crypt(obj, DecryptStringAction);
}

private static T Crypt<T>(T obj, Action<PropertyHelper<T, string>, T> action)
{
    var stringProperties = new AggregateHelper<T, string>();
    foreach (var property in stringProperties.Properties)
    {
        var propertyHelper = stringProperties[property];
        action(propertyHelper, obj);
    }

    // how do I find the nested classes, interfaces and lists
    // using Expression Trees to feed them through the same processing?

    return obj;
}

private static void EncryptStringAction<T>(PropertyHelper<T, string> prop, T genericObj)
{
    var plainTextValue = (string)prop.GetValue(genericObj);
    prop.SetValue(genericObj, plainTextValue.ToEncryptedString());
}

private static void DecryptStringAction<T>(PropertyHelper<T, string> prop, T genericObj)
{
    var encryptedValue = (string)prop.GetValue(genericObj);
    prop.SetValue(genericObj, encryptedValue.ToDecryptedString());
}

This works great as far as it goes; it encrypts string properties on any object, but I need to take it further - I need some way to recursively go "down the rabbit hole" - to create an AggregateHelper that selects properties that are instance objects (classes or interfaces) and feed those through the .EncryptStringProperties() extension method, as well as handling any IEnumerable with string values.
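
The property discovery described above can be sketched with plain reflection, run once per type, before any expression tree is involved. This is only an illustrative sketch: PropertyKind, PropertyClassifier, Sample and IChild are hypothetical names and are not part of AggregateHelper or PropertyHelper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical types used only to exercise the classifier.
public interface IChild { string Label { get; set; } }

public class Sample
{
    public string Name { get; set; }
    public IEnumerable<Sample> Items { get; set; }
    public IChild Child { get; set; }
    public int Count { get; set; }
}

public enum PropertyKind { String, Sequence, Nested, Other }

public static class PropertyClassifier
{
    // Returns T when the type is, or implements, IEnumerable<T>; otherwise null.
    // The type itself is checked too, because GetInterfaces() does not include it.
    public static Type GetElementType(Type type) =>
        new[] { type }
            .Concat(type.GetInterfaces())
            .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            .Select(i => i.GetGenericArguments()[0])
            .FirstOrDefault();

    public static PropertyKind Classify(PropertyInfo prop)
    {
        // Indexers and write-only properties cannot be read as plain values.
        if (!prop.CanRead || prop.GetIndexParameters().Length > 0)
            return PropertyKind.Other;

        var type = prop.PropertyType;

        // Checked first, because string also implements IEnumerable<char>.
        if (type == typeof(string))
            return PropertyKind.String;

        if (GetElementType(type) != null)
            return PropertyKind.Sequence;

        // IsClass is false for interface types, so both are tested.
        if (type.IsClass || type.IsInterface)
            return PropertyKind.Nested;

        return PropertyKind.Other;
    }
}
```

For example, PropertyClassifier.Classify(typeof(Sample).GetProperty("Child")) yields PropertyKind.Nested, while the Count property yields PropertyKind.Other.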

Answer:

Try the following extension. It builds a replacement expression tree for each type it encounters and compiles it into a delegate, so string replacement is fast. The compiled delegates are cached per type, so the compilation cost is paid only once. It also tracks visited objects, so circular references should not cause a stack overflow.
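
The cycle protection depends on tracking visited instances by reference rather than by Equals. A minimal sketch of why that matters, using hypothetical Node, ReferenceComparer and CycleDemo types that are not part of the extension itself:

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

// Hypothetical type: distinct instances with the same Name compare equal.
public class Node
{
    public string Name { get; set; }
    public Node Next { get; set; }

    public override bool Equals(object obj) => obj is Node n && n.Name == Name;
    public override int GetHashCode() => Name?.GetHashCode() ?? 0;
}

// Compares by object identity, ignoring any Equals/GetHashCode overrides.
public sealed class ReferenceComparer : IEqualityComparer<object>
{
    bool IEqualityComparer<object>.Equals(object x, object y) => ReferenceEquals(x, y);
    int IEqualityComparer<object>.GetHashCode(object obj) => RuntimeHelpers.GetHashCode(obj);
}

public static class CycleDemo
{
    // Builds two equal-looking nodes that point at each other: a -> b -> a.
    public static Node BuildCycle()
    {
        var a = new Node { Name = "same" };
        var b = new Node { Name = "same", Next = a };
        a.Next = b;
        return a;
    }

    // Follows Next until it reaches null or a node the set has already seen.
    public static int CountReachable(Node start, HashSet<object> visited)
    {
        var count = 0;
        for (var node = start; node != null && visited.Add(node); node = node.Next)
            count++;
        return count;
    }
}
```

With new HashSet<object>(new ReferenceComparer()), CountReachable(CycleDemo.BuildCycle(), visited) stops after both nodes and returns 2. With a default HashSet<object>, the second node is treated as already visited because it equals the first, so the walk returns 1 and that node would never be processed. On .NET 5 and later, the built-in ReferenceEqualityComparer.Instance can serve the same purpose.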

Usage:

StringPropReplacer.ReplaceStrings(obj, s => Encrypt(s));

Implementation:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;

namespace StringReplacer
{
    public static class StringPropReplacer
    {
        private static ParameterExpression _strParam = Expression.Parameter(typeof(string), "str");
        private static ParameterExpression _originalValue = Expression.Variable(typeof(string), "original");
        private static ParameterExpression _newValue = Expression.Variable(typeof(string), "newValue");
        private static ParameterExpression _visitedParam = Expression.Parameter(typeof(HashSet<object>), "visited");
        private static ParameterExpression _replacerParam = Expression.Parameter(typeof(Func<string, string>), "replacer");
        private static ParameterExpression[] _blockVariables = new [] {_originalValue, _newValue};

        private static void ReplaceObject<T>(T obj, HashSet<object> visited, Func<string, string> replacer)
        {
            ReflectionHolder<T>.ReplaceObject(obj, visited, replacer);
        }

        private static void ReplaceObjects<T>(IEnumerable<T> objects, HashSet<object> visited, Func<string, string> replacer)
        {
            ReflectionHolder<T>.ReplaceObjects(objects, visited, replacer);
        }

        public class ObjectReferenceEqualityComparer<T> : IEqualityComparer<T>
            where T : notnull
        {
            public static readonly IEqualityComparer<T> Default = new ObjectReferenceEqualityComparer<T>();

            #region IEqualityComparer<T> Members

            public bool Equals(T x, T y)
            {
                return ReferenceEquals(x, y);
            }

            public int GetHashCode(T obj)
            {
                if (obj == null)
                    return 0;

                return RuntimeHelpers.GetHashCode(obj);
            }

            #endregion
        }


        private static class ReflectionHolder<T>
        {
            private static Action<T, HashSet<object>, Func<string, string>> _replacer;

            static ReflectionHolder()
            {
                var objParam = Expression.Parameter(typeof(T), "obj");
                var blockBody = new List<Expression>();

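                foreach (var prop in typeof(T).GetProperties())
                {
                    // Skip indexers and write-only properties: they cannot be read as plain member accesses.
                    if (!prop.CanRead || prop.GetIndexParameters().Length > 0)
                        continue;
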
                    if (prop.PropertyType == typeof(string) && prop.CanRead && prop.CanWrite)
                    {
                        var propExpression = Expression.MakeMemberAccess(objParam, prop);
                        blockBody.Add(Expression.Assign(_originalValue, propExpression));
                        blockBody.Add(Expression.Assign(_newValue, Expression.Invoke(_replacerParam, _originalValue)));
                        blockBody.Add(Expression.IfThen(Expression.NotEqual(_originalValue, _newValue),
                            Expression.Assign(propExpression, _newValue)));
                    }
                    else if (typeof(IEnumerable).IsAssignableFrom(prop.PropertyType))
                    {
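                        // Include the property type itself, since GetInterfaces() omits it
                        // when the property is declared as IEnumerable<T> directly.
                        var intf = new[] { prop.PropertyType }
                            .Concat(prop.PropertyType.GetInterfaces())
                            .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));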

                        if (intf != null)
                        {
                            var propExpression = Expression.MakeMemberAccess(objParam, prop);
                            blockBody.Add(Expression.Call(typeof(StringPropReplacer), "ReplaceObjects",
                                intf.GetGenericArguments(), propExpression, _visitedParam, _replacerParam
                            ));
                        }
                    }
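                    // IsClass is false for interface types, so interface-typed properties are tested explicitly.
                    else if (prop.PropertyType.IsClass || prop.PropertyType.IsInterface)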
                    {
                        var propExpression = Expression.MakeMemberAccess(objParam, prop);
                        blockBody.Add(Expression.Call(typeof(StringPropReplacer), "ReplaceObject",
                            new[] {prop.PropertyType}, propExpression, _visitedParam, _replacerParam
                        ));
                    }
                    
                }

                if (blockBody.Count == 0)
                    _replacer = (o, v, f) => { };
                else
                {
                    var replacerExpr = Expression.Lambda<Action<T, HashSet<object>, Func<string, string>>>(
                        Expression.Block(_blockVariables, blockBody), objParam, _visitedParam, _replacerParam);
                    _replacer = replacerExpr.Compile();
                }
            }

            public static void ReplaceObject(T obj, HashSet<object> visited, Func<string, string> replacer)
            {
                if (obj == null)
                    return;

                if (!visited.Add(obj))
                    return;
                
                _replacer(obj, visited, replacer);
            }

            public static void ReplaceObjects(IEnumerable<T> objects, HashSet<object> visited, Func<string, string> replacer)
            {
                if (objects == null)
                    return;

                if (!visited.Add(objects))
                    return;

                foreach (var obj in objects)
                {
                    ReplaceObject(obj, visited, replacer);   
                }
            }
        }

        public static void ReplaceStrings<T>(T obj, Func<string, string> replacer)
        {
            ReplaceObject(obj, new HashSet<object>(ObjectReferenceEqualityComparer<object>.Default), replacer);
        }
    }
}
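
One case from the question that the code above does not appear to cover is a collection whose elements are themselves strings, such as the List<string> StringList property. ReplaceObjects<string> ends up in ReflectionHolder<string>, which finds no writable string properties on System.String, so the elements are left unchanged. A minimal sketch of one way to handle it, assuming the list is mutable; StringListReplacer and ReplaceElements are hypothetical names, not part of the class above:

```csharp
using System;
using System.Collections.Generic;

public static class StringListReplacer
{
    // Hypothetical helper: rewrites each element of a mutable string list in place.
    public static void ReplaceElements(IList<string> list, Func<string, string> replacer)
    {
        // Null and read-only lists are skipped rather than throwing.
        if (list == null || list.IsReadOnly)
            return;

        for (var i = 0; i < list.Count; i++)
        {
            var original = list[i];
            if (original == null)
                continue;

            // Only write back when the replacer actually changed the value.
            var replaced = replacer(original);
            if (!string.Equals(original, replaced, StringComparison.Ordinal))
                list[i] = replaced;
        }
    }
}
```

Calling StringListReplacer.ReplaceElements(list, s => s.ToUpperInvariant()) on a list containing "one" and "two" leaves it containing "ONE" and "TWO". To wire this in, the IEnumerable branch of the static constructor could first test typeof(IList<string>).IsAssignableFrom(prop.PropertyType) and emit an Expression.Call to a helper like this one instead of ReplaceObjects.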