2020/01/28, ASP.NET Core 3.1, VS2019, Microsoft.EntityFrameworkCore.Relational 3.1.1
摘要:基于ASP.NET Core 3.1 WebApi搭建后端多层网站架构【4-工作单元和仓储设计】
使用泛型仓储(Repository)和工作单元(UnitOfWork)模式封装数据访问层基础的增删改查等方法
文章目录
此分支项目代码
关于本章节的工作单元模式:
泛型仓储封装了通用的增删改查方法,由工作单元统一管理仓储,以保证各仓储使用同一个数据库上下文。
要获取仓储,都从工作单元中获取;通过仓储改动数据库后,由工作单元统一进行提交。
代码参考了Arch/UnitOfWork的设计,大部分实现都来自该项目,然后加了一些中文注释,并去除了分布式多库支持。
向MS.UnitOfWork项目添加对Microsoft.EntityFrameworkCore.Relational包的引用:
<ItemGroup>
  <!-- EF Core relational package (it depends on EF Core itself); version 3.1.1 matches the ASP.NET Core 3.1 stack of this project. -->
  <PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.1" />
</ItemGroup>
在MS.UnitOfWork项目中添加Collections文件夹,在该文件夹下添加IPagedList.cs、PagedList.cs、IEnumerablePagedListExtensions.cs、IQueryablePageListExtensions.cs类。
using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Read-only view of one page of data together with the metadata needed to navigate between pages.
    /// </summary>
    /// <typeparam name="T">Type of the items being paged.</typeparam>
    public interface IPagedList<T>
    {
        /// <summary>Index value assigned to the first page (for example 0 or 1).</summary>
        int IndexFrom { get; }

        /// <summary>Index of the page this instance represents.</summary>
        int PageIndex { get; }

        /// <summary>Maximum number of items per page.</summary>
        int PageSize { get; }

        /// <summary>Total number of items across all pages.</summary>
        int TotalCount { get; }

        /// <summary>Total number of pages.</summary>
        int TotalPages { get; }

        /// <summary>Items that belong to the current page.</summary>
        IList<T> Items { get; }

        /// <summary>Whether a page precedes the current one.</summary>
        bool HasPreviousPage { get; }

        /// <summary>Whether a page follows the current one.</summary>
        bool HasNextPage { get; }
    }
}
using System;
using System.Collections.Generic;
using System.Linq;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Default implementation of <see cref="IPagedList{T}"/>: one materialized page of items plus paging metadata.
    /// </summary>
    /// <typeparam name="T">Type of the paged items.</typeparam>
    public class PagedList<T> : IPagedList<T>
    {
        /// <summary>Index of the current page.</summary>
        public int PageIndex { get; set; }

        /// <summary>Maximum number of items per page.</summary>
        public int PageSize { get; set; }

        /// <summary>Total number of items in the source.</summary>
        public int TotalCount { get; set; }

        /// <summary>Total number of pages (TotalCount / PageSize, rounded up).</summary>
        public int TotalPages { get; set; }

        /// <summary>Index value assigned to the first page.</summary>
        public int IndexFrom { get; set; }

        /// <summary>Items on the current page.</summary>
        public IList<T> Items { get; set; }

        /// <summary>True when the current page is not the first page.</summary>
        public bool HasPreviousPage => PageIndex - IndexFrom > 0;

        /// <summary>True when at least one page follows the current page.</summary>
        public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

        /// <summary>
        /// Counts <paramref name="source"/> and materializes page <paramref name="pageIndex"/> of it.
        /// </summary>
        /// <param name="source">
        /// Data to page. When it is an <see cref="IQueryable{T}"/>, counting and slicing go through
        /// the query provider; otherwise they run in memory.
        /// </param>
        /// <param name="pageIndex">Index of the page to take.</param>
        /// <param name="pageSize">Maximum number of items per page; must be greater than zero.</param>
        /// <param name="indexFrom">Index value of the first page; must not exceed <paramref name="pageIndex"/>.</param>
        /// <exception cref="ArgumentException"><paramref name="indexFrom"/> is greater than <paramref name="pageIndex"/>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="pageSize"/> is zero or negative.</exception>
        internal PagedList(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom)
        {
            PagedList.ThrowIfInvalidPaging(pageIndex, pageSize, indexFrom);

            PageIndex = pageIndex;
            PageSize = pageSize;
            IndexFrom = indexFrom;

            int skip = (pageIndex - indexFrom) * pageSize;
            IEnumerable<T> page;
            if (source is IQueryable<T> queryable)
            {
                // The static type IQueryable<T> binds Count/Skip/Take to System.Linq.Queryable,
                // so a provider such as EF Core can execute them in the database.
                TotalCount = queryable.Count();
                page = queryable.Skip(skip).Take(pageSize);
            }
            else
            {
                TotalCount = source.Count();
                page = source.Skip(skip).Take(pageSize);
            }

            TotalPages = (int)Math.Ceiling(TotalCount / (double)pageSize);
            Items = page.ToList();
        }

        /// <summary>
        /// Creates an empty page (no items, all counters zero).
        /// </summary>
        internal PagedList() => Items = Array.Empty<T>();
    }

    /// <summary>
    /// Pages a source sequence and converts the page's items to another type.
    /// </summary>
    /// <typeparam name="TSource">Type of the source items.</typeparam>
    /// <typeparam name="TResult">Type of the converted items.</typeparam>
    internal class PagedList<TSource, TResult> : IPagedList<TResult>
    {
        /// <summary>Index of the current page.</summary>
        public int PageIndex { get; set; }

        /// <summary>Maximum number of items per page.</summary>
        public int PageSize { get; set; }

        /// <summary>Total number of items in the source.</summary>
        public int TotalCount { get; set; }

        /// <summary>Total number of pages.</summary>
        public int TotalPages { get; set; }

        /// <summary>Index value assigned to the first page.</summary>
        public int IndexFrom { get; set; }

        /// <summary>Converted items on the current page.</summary>
        public IList<TResult> Items { get; set; }

        /// <summary>True when the current page is not the first page.</summary>
        public bool HasPreviousPage => PageIndex - IndexFrom > 0;

        /// <summary>True when at least one page follows the current page.</summary>
        public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;

        /// <summary>
        /// Counts <paramref name="source"/>, takes page <paramref name="pageIndex"/> and converts it.
        /// </summary>
        /// <param name="source">Data to page (an <see cref="IQueryable{T}"/> is handled by its query provider).</param>
        /// <param name="converter">Converts the materialized source page into result items.</param>
        /// <param name="pageIndex">Index of the page to take.</param>
        /// <param name="pageSize">Maximum number of items per page; must be greater than zero.</param>
        /// <param name="indexFrom">Index value of the first page; must not exceed <paramref name="pageIndex"/>.</param>
        /// <exception cref="ArgumentException"><paramref name="indexFrom"/> is greater than <paramref name="pageIndex"/>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="pageSize"/> is zero or negative.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="converter"/> is null.</exception>
        public PagedList(IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom)
        {
            PagedList.ThrowIfInvalidPaging(pageIndex, pageSize, indexFrom);
            if (converter == null)
            {
                throw new ArgumentNullException(nameof(converter));
            }

            PageIndex = pageIndex;
            PageSize = pageSize;
            IndexFrom = indexFrom;

            int skip = (pageIndex - indexFrom) * pageSize;
            TSource[] sourcePage;
            if (source is IQueryable<TSource> queryable)
            {
                // Bind to Queryable operators so the provider can translate count and slice.
                TotalCount = queryable.Count();
                sourcePage = queryable.Skip(skip).Take(pageSize).ToArray();
            }
            else
            {
                TotalCount = source.Count();
                sourcePage = source.Skip(skip).Take(pageSize).ToArray();
            }

            TotalPages = (int)Math.Ceiling(TotalCount / (double)pageSize);
            Items = new List<TResult>(converter(sourcePage));
        }

        /// <summary>
        /// Copies the paging metadata of an existing page and converts its items.
        /// </summary>
        /// <param name="source">The page to convert.</param>
        /// <param name="converter">Converts the page's items.</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="converter"/> is null.</exception>
        public PagedList(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }
            if (converter == null)
            {
                throw new ArgumentNullException(nameof(converter));
            }

            PageIndex = source.PageIndex;
            PageSize = source.PageSize;
            IndexFrom = source.IndexFrom;
            TotalCount = source.TotalCount;
            TotalPages = source.TotalPages;
            Items = new List<TResult>(converter(source.Items));
        }
    }

    /// <summary>
    /// Helper methods for <see cref="IPagedList{T}"/>.
    /// </summary>
    public static class PagedList
    {
        /// <summary>
        /// Creates an empty <see cref="IPagedList{T}"/>.
        /// </summary>
        /// <typeparam name="T">Type of the paged items.</typeparam>
        /// <returns>A page with no items and all counters set to zero.</returns>
        public static IPagedList<T> Empty<T>() => new PagedList<T>();

        /// <summary>
        /// Creates an <see cref="IPagedList{TResult}"/> from an existing <see cref="IPagedList{TSource}"/>,
        /// keeping its paging metadata and converting its items.
        /// </summary>
        /// <typeparam name="TResult">Type of the converted items.</typeparam>
        /// <typeparam name="TSource">Type of the source items.</typeparam>
        /// <param name="source">The page to convert.</param>
        /// <param name="converter">Converts the page's items.</param>
        /// <returns>The converted page.</returns>
        public static IPagedList<TResult> From<TResult, TSource>(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter) => new PagedList<TSource, TResult>(source, converter);

        /// <summary>
        /// Validates paging arguments shared by all page constructors.
        /// A non-positive page size would otherwise divide by zero when computing TotalPages.
        /// </summary>
        internal static void ThrowIfInvalidPaging(int pageIndex, int pageSize, int indexFrom)
        {
            if (indexFrom > pageIndex)
            {
                throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex},起始页必须小于等于当前页");
            }
            if (pageSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(pageSize), pageSize, $"pageSize: {pageSize},页大小必须大于0");
            }
        }
    }
}
using System;
using System.Collections.Generic;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Paging extension methods for any <see cref="IEnumerable{T}"/>.
    /// </summary>
    public static class IEnumerablePagedListExtensions
    {
        /// <summary>
        /// Takes one page out of <paramref name="source"/>.
        /// </summary>
        /// <typeparam name="T">Type of the items.</typeparam>
        /// <param name="source">Sequence to page.</param>
        /// <param name="pageIndex">Index of the page to take.</param>
        /// <param name="pageSize">Maximum number of items per page.</param>
        /// <param name="indexFrom">Index value of the first page; defaults to 1.</param>
        /// <returns>The requested page together with paging metadata.</returns>
        public static IPagedList<T> ToPagedList<T>(this IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1)
        {
            var page = new PagedList<T>(source, pageIndex, pageSize, indexFrom);
            return page;
        }

        /// <summary>
        /// Takes one page out of <paramref name="source"/> and converts its items with <paramref name="converter"/>.
        /// </summary>
        /// <typeparam name="TSource">Type of the source items.</typeparam>
        /// <typeparam name="TResult">Type of the converted items.</typeparam>
        /// <param name="source">Sequence to page.</param>
        /// <param name="converter">Converts the items of the selected page.</param>
        /// <param name="pageIndex">Index of the page to take.</param>
        /// <param name="pageSize">Maximum number of items per page.</param>
        /// <param name="indexFrom">Index value of the first page; defaults to 1.</param>
        /// <returns>The converted page together with paging metadata.</returns>
        public static IPagedList<TResult> ToPagedList<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom = 1)
        {
            var page = new PagedList<TSource, TResult>(source, converter, pageIndex, pageSize, indexFrom);
            return page;
        }
    }
}
using Microsoft.EntityFrameworkCore;
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork.Collections
{
    /// <summary>
    /// Asynchronous paging extensions for <see cref="IQueryable{T}"/>, using EF Core's async query operators.
    /// </summary>
    public static class IQueryablePageListExtensions
    {
        /// <summary>
        /// Asynchronously counts <paramref name="source"/> and loads one page of it.
        /// Two queries are issued: a count, then the Skip/Take page.
        /// </summary>
        /// <typeparam name="T">Type of the items.</typeparam>
        /// <param name="source">Query to page. It must come from an EF Core async-capable provider.</param>
        /// <param name="pageIndex">Index of the page to load.</param>
        /// <param name="pageSize">Maximum number of items per page; must be greater than zero.</param>
        /// <param name="indexFrom">Index value of the first page; defaults to 1 and must not exceed <paramref name="pageIndex"/>.</param>
        /// <param name="cancellationToken">Token observed by both database round-trips.</param>
        /// <returns>The requested page together with paging metadata.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="indexFrom"/> is greater than <paramref name="pageIndex"/>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="pageSize"/> is zero or negative.</exception>
        public static async Task<IPagedList<T>> ToPagedListAsync<T>(this IQueryable<T> source, int pageIndex, int pageSize, int indexFrom = 1, CancellationToken cancellationToken = default(CancellationToken))
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }
            if (indexFrom > pageIndex)
            {
                throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex");
            }
            if (pageSize <= 0)
            {
                // Guard against division by zero when computing TotalPages below.
                throw new ArgumentOutOfRangeException(nameof(pageSize), pageSize, $"pageSize: {pageSize}, must be greater than 0");
            }

            var count = await source.CountAsync(cancellationToken).ConfigureAwait(false);
            var items = await source
                .Skip((pageIndex - indexFrom) * pageSize)
                .Take(pageSize)
                .ToListAsync(cancellationToken)
                .ConfigureAwait(false);

            var pagedList = new PagedList<T>()
            {
                PageIndex = pageIndex,
                PageSize = pageSize,
                IndexFrom = indexFrom,
                TotalCount = count,
                Items = items,
                TotalPages = (int)Math.Ceiling(count / (double)pageSize)
            };
            return pagedList;
        }
    }
}
针对IQueryable、IEnumerable类型的数据做了分页扩展方法封装,主要用于从数据库获取数据时进行分页筛选。
在MS.UnitOfWork项目中添加Repository文件夹,在该文件夹下添加IRepository.cs、Repository.cs类。
using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Generic repository contract for one entity type: querying, paging, key lookup,
    /// and insert / update / delete operations.
    /// </summary>
    /// <typeparam name="TEntity">The entity type.</typeparam>
    /// <remarks>
    /// Query methods that take them default to <c>disableTracking = true</c> (results are not tracked,
    /// treat them as read-only) and <c>ignoreQueryFilters = false</c> (global query filters stay active).
    /// </remarks>
    public interface IRepository<TEntity> where TEntity : class
    {
        #region GetAll

        /// <summary>
        /// Returns a query over all entities with no predicate or paging applied.
        /// Materializing it can load the whole set, so watch performance.
        /// </summary>
        /// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
        IQueryable<TEntity> GetAll();

        /// <summary>
        /// Returns a query over the entities, optionally filtered, ordered and with navigations included.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The composed query.</returns>
        IQueryable<TEntity> GetAll(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

        /// <summary>
        /// Returns a projected query over the entities, optionally filtered, ordered and with navigations included.
        /// </summary>
        /// <typeparam name="TResult">Projected result type.</typeparam>
        /// <param name="selector">Projection applied to each entity.</param>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The composed, projected query.</returns>
        IQueryable<TResult> GetAll<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false) where TResult : class;

        /// <summary>
        /// Asynchronously loads all entities matching the optional filter into a list.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The loaded entities.</returns>
        Task<IList<TEntity>> GetAllAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

        #endregion

        #region GetPagedList

        /// <summary>
        /// Loads one page of entities.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="pageIndex">Page to load, counted from 1. Defaults to 1.</param>
        /// <param name="pageSize">Items per page. Defaults to 20.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The requested page.</returns>
        IPagedList<TEntity> GetPagedList(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

        /// <summary>
        /// Asynchronously loads one page of entities.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="pageIndex">Page to load, counted from 1. Defaults to 1.</param>
        /// <param name="pageSize">Items per page. Defaults to 20.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>The requested page.</returns>
        Task<IPagedList<TEntity>> GetPagedListAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

        /// <summary>
        /// Loads one page of projected results.
        /// </summary>
        /// <typeparam name="TResult">Projected result type.</typeparam>
        /// <param name="selector">Projection applied to each entity.</param>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="pageIndex">Page to load, counted from 1. Defaults to 1.</param>
        /// <param name="pageSize">Items per page. Defaults to 20.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The requested page of projected results.</returns>
        IPagedList<TResult> GetPagedList<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false) where TResult : class;

        /// <summary>
        /// Asynchronously loads one page of projected results.
        /// </summary>
        /// <typeparam name="TResult">Projected result type.</typeparam>
        /// <param name="selector">Projection applied to each entity.</param>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering to apply; null means no ordering.</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="pageIndex">Page to load, counted from 1. Defaults to 1.</param>
        /// <param name="pageSize">Items per page. Defaults to 20.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>The requested page of projected results.</returns>
        Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            int pageIndex = 1,
            int pageSize = 20,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default) where TResult : class;

        #endregion

        #region GetFirstOrDefault

        /// <summary>
        /// Returns the first entity matching the condition, or the default value when none matches.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering that decides which match is "first".</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The first match, or the default value.</returns>
        TEntity GetFirstOrDefault(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

        /// <summary>
        /// Asynchronously returns the first entity matching the condition, or the default value when none matches.
        /// </summary>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering that decides which match is "first".</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>The first match, or the default value.</returns>
        Task<TEntity> GetFirstOrDefaultAsync(
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

        /// <summary>
        /// Returns the projection of the first entity matching the condition, or the default value when none matches.
        /// </summary>
        /// <typeparam name="TResult">Projected result type.</typeparam>
        /// <param name="selector">Projection applied to each entity.</param>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering that decides which match is "first".</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <returns>The first projected match, or the default value.</returns>
        TResult GetFirstOrDefault<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false);

        /// <summary>
        /// Asynchronously returns the projection of the first entity matching the condition,
        /// or the default value when none matches.
        /// </summary>
        /// <typeparam name="TResult">Projected result type.</typeparam>
        /// <param name="selector">Projection applied to each entity.</param>
        /// <param name="predicate">Filter condition; null means no filter.</param>
        /// <param name="orderBy">Ordering that decides which match is "first".</param>
        /// <param name="include">Navigation properties to include.</param>
        /// <param name="disableTracking">True to run a no-tracking query. Defaults to true.</param>
        /// <param name="ignoreQueryFilters">True to ignore global query filters. Defaults to false.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>The first projected match, or the default value.</returns>
        Task<TResult> GetFirstOrDefaultAsync<TResult>(
            Expression<Func<TEntity, TResult>> selector,
            Expression<Func<TEntity, bool>> predicate = null,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
            bool disableTracking = true,
            bool ignoreQueryFilters = false,
            CancellationToken cancellationToken = default);

        #endregion

        #region Find

        /// <summary>
        /// Finds an entity by its primary key values; returns null when no entity is found.
        /// </summary>
        /// <param name="keyValues">Primary key values of the entity to find.</param>
        /// <returns>The entity, or null.</returns>
        TEntity Find(params object[] keyValues);

        /// <summary>
        /// Asynchronously finds an entity by its primary key values; the result is null when no entity is found.
        /// </summary>
        /// <param name="keyValues">Primary key values of the entity to find.</param>
        /// <returns>A task whose result is the entity, or null.</returns>
        ValueTask<TEntity> FindAsync(params object[] keyValues);

        /// <summary>
        /// Asynchronously finds an entity by its primary key values; the result is null when no entity is found.
        /// </summary>
        /// <param name="keyValues">Primary key values of the entity to find.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>A task whose result is the entity, or null.</returns>
        ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken);

        #endregion

        #region sql、count、exist

        /// <summary>
        /// Queries entities with a raw SQL statement.
        /// Pass user-supplied values through <paramref name="parameters"/>, never by concatenating them into <paramref name="sql"/>.
        /// </summary>
        /// <param name="sql">The raw SQL text.</param>
        /// <param name="parameters">Parameter values referenced by the SQL text.</param>
        /// <returns>A query composed on top of the raw SQL.</returns>
        IQueryable<TEntity> FromSql(string sql, params object[] parameters);

        /// <summary>
        /// Counts entities matching the optional condition.
        /// </summary>
        /// <param name="predicate">Filter condition; null counts every entity.</param>
        /// <returns>The number of matching entities.</returns>
        int Count(Expression<Func<TEntity, bool>> predicate = null);

        /// <summary>
        /// Asynchronously counts entities matching the optional condition.
        /// </summary>
        /// <param name="predicate">Filter condition; null counts every entity.</param>
        /// <returns>A task whose result is the number of matching entities.</returns>
        Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null);

        /// <summary>
        /// Checks whether any entity matches the optional condition.
        /// </summary>
        /// <param name="predicate">Filter condition; null checks whether any entity exists.</param>
        /// <returns>True when at least one entity matches.</returns>
        bool Exists(Expression<Func<TEntity, bool>> predicate = null);

        #endregion

        #region Insert

        /// <summary>
        /// Stages a new entity for insertion.
        /// </summary>
        /// <param name="entity">The entity to insert.</param>
        /// <returns>The staged entity.</returns>
        TEntity Insert(TEntity entity);

        /// <summary>
        /// Stages several entities for insertion.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        void Insert(params TEntity[] entities);

        /// <summary>
        /// Stages several entities for insertion.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        void Insert(IEnumerable<TEntity> entities);

        /// <summary>
        /// Asynchronously stages a new entity for insertion.
        /// </summary>
        /// <param name="entity">The entity to insert.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>A task whose result is the change-tracking entry of the entity.</returns>
        ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default);

        /// <summary>
        /// Asynchronously stages several entities for insertion.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        /// <returns>A task that completes when the entities are staged.</returns>
        Task InsertAsync(params TEntity[] entities);

        /// <summary>
        /// Asynchronously stages several entities for insertion.
        /// </summary>
        /// <param name="entities">The entities to insert.</param>
        /// <param name="cancellationToken">Token to observe while waiting.</param>
        /// <returns>A task that completes when the entities are staged.</returns>
        Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);

        #endregion

        #region Update

        /// <summary>
        /// Stages the entity for update.
        /// </summary>
        /// <param name="entity">The entity.</param>
        void Update(TEntity entity);

        /// <summary>
        /// Stages several entities for update.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Update(params TEntity[] entities);

        /// <summary>
        /// Stages several entities for update.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Update(IEnumerable<TEntity> entities);

        #endregion

        #region Delete

        /// <summary>
        /// Stages deletion of the entity with the given primary key value.
        /// </summary>
        /// <param name="id">The primary key value.</param>
        void Delete(object id);

        /// <summary>
        /// Stages deletion of the entity.
        /// </summary>
        /// <param name="entity">The entity to delete.</param>
        void Delete(TEntity entity);

        /// <summary>
        /// Stages deletion of several entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Delete(params TEntity[] entities);

        /// <summary>
        /// Stages deletion of several entities.
        /// </summary>
        /// <param name="entities">The entities.</param>
        void Delete(IEnumerable<TEntity> entities);

        #endregion
    }
}
using MS.UnitOfWork.Collections; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Query; using System; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Threading; using System.Threading.Tasks; namespace MS.UnitOfWork { /// <summary> /// 通用仓储的默认实现 /// </summary> /// <typeparam name="TEntity"></typeparam> public class Repository<TEntity> : IRepository<TEntity> where TEntity : class { protected readonly DbContext _dbContext; protected readonly DbSet<TEntity> _dbSet; public Repository(DbContext dbContext) { _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext)); _dbSet = _dbContext.Set<TEntity>(); } #region GetAll public IQueryable<TEntity> GetAll() => _dbSet; public IQueryable<TEntity> GetAll( Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query); } else { return query; } } public IQueryable<TResult> GetAll<TResult>( Expression<Func<TEntity, TResult>> selector, Expression<Func<TEntity, bool>> predicate, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) where TResult : class { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != 
null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query).Select(selector); } else { return query.Select(selector); } } public async Task<IList<TEntity>> GetAllAsync(Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return await orderBy(query).ToListAsync(); } else { return await query.ToListAsync(); } } #endregion #region GetPagedList public virtual IPagedList<TEntity> GetPagedList( Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, int pageIndex = 1, int pageSize = 20, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query).ToPagedList(pageIndex, pageSize); } else { return query.ToPagedList(pageIndex, pageSize); } } public virtual async Task<IPagedList<TEntity>> GetPagedListAsync( Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, int pageIndex = 1, int pageSize = 20, bool 
disableTracking = true, bool ignoreQueryFilters = false, CancellationToken cancellationToken = default) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return await orderBy(query).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken); } else { return await query.ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken); } } public virtual IPagedList<TResult> GetPagedList<TResult>( Expression<Func<TEntity, TResult>> selector, Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, int pageIndex = 1, int pageSize = 20, bool disableTracking = true, bool ignoreQueryFilters = false) where TResult : class { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query).Select(selector).ToPagedList(pageIndex, pageSize); } else { return query.Select(selector).ToPagedList(pageIndex, pageSize); } } public virtual async Task<IPagedList<TResult>> GetPagedListAsync<TResult>( Expression<Func<TEntity, TResult>> selector, Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, int pageIndex = 1, int pageSize = 20, bool disableTracking = true, bool ignoreQueryFilters = false, CancellationToken cancellationToken = default) where TResult : class { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = 
query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return await orderBy(query).Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken); } else { return await query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken); } } #endregion #region GetFirstOrDefault public virtual TEntity GetFirstOrDefault( Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query).FirstOrDefault(); } else { return query.FirstOrDefault(); } } public virtual async Task<TEntity> GetFirstOrDefaultAsync( Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false, CancellationToken cancellationToken = default) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return await orderBy(query).FirstOrDefaultAsync(cancellationToken); } else { return await query.FirstOrDefaultAsync(cancellationToken); } } public virtual TResult GetFirstOrDefault<TResult>( 
Expression<Func<TEntity, TResult>> selector, Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return orderBy(query).Select(selector).FirstOrDefault(); } else { return query.Select(selector).FirstOrDefault(); } } public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>( Expression<Func<TEntity, TResult>> selector, Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false, CancellationToken cancellationToken = default) { IQueryable<TEntity> query = _dbSet; if (disableTracking) { query = query.AsNoTracking(); } if (include != null) { query = include(query); } if (predicate != null) { query = query.Where(predicate); } if (ignoreQueryFilters) { query = query.IgnoreQueryFilters(); } if (orderBy != null) { return await orderBy(query).Select(selector).FirstOrDefaultAsync(cancellationToken); } else { return await query.Select(selector).FirstOrDefaultAsync(cancellationToken); } } #endregion #region Find public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues); public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues); public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken); #endregion #region sql、count、exist 
/// <summary>
/// Creates a query over the entity set from raw SQL (EF Core <c>FromSqlRaw</c>).
/// </summary>
/// <param name="sql">
/// The SQL text. Supply untrusted values through <paramref name="parameters"/> instead of
/// concatenating them into this string; otherwise the query is open to SQL injection.
/// </param>
/// <param name="parameters">Values handed to <c>FromSqlRaw</c> together with <paramref name="sql"/>.</param>
/// <returns>A composable query that has not been executed yet.</returns>
public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);

/// <summary>
/// Counts entities, optionally restricted by a filter.
/// </summary>
/// <param name="predicate">Filter to apply; when <c>null</c> the whole set is counted.</param>
/// <returns>The number of matching entities.</returns>
public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
{
    return predicate == null ? _dbSet.Count() : _dbSet.Count(predicate);
}

/// <summary>
/// Asynchronously counts entities, optionally restricted by a filter.
/// </summary>
/// <param name="predicate">Filter to apply; when <c>null</c> the whole set is counted.</param>
/// <returns>A task producing the number of matching entities.</returns>
public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
{
    return predicate == null
        ? await _dbSet.CountAsync()
        : await _dbSet.CountAsync(predicate);
}

/// <summary>
/// Checks whether any entity exists, optionally restricted by a filter.
/// </summary>
/// <param name="predicate">Filter to apply; when <c>null</c> the check is "is the set non-empty".</param>
/// <returns><c>true</c> if at least one entity matches.</returns>
public virtual bool Exists(Expression<Func<TEntity, bool>> predicate = null)
{
    return predicate == null ? _dbSet.Any() : _dbSet.Any(predicate);
}

#endregion

#region Insert

/// <summary>
/// Starts tracking <paramref name="entity"/> in the Added state; the row is written when the
/// context's changes are saved.
/// </summary>
/// <param name="entity">The entity to insert.</param>
/// <returns>The tracked entity instance.</returns>
public virtual TEntity Insert(TEntity entity) => _dbSet.Add(entity).Entity;

/// <summary>
/// Starts tracking every entity in <paramref name="entities"/> in the Added state.
/// </summary>
/// <param name="entities">The entities to insert.</param>
public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);

/// <summary>
/// Starts tracking every entity in <paramref name="entities"/> in the Added state.
/// </summary>
/// <param name="entities">The entities to insert.</param>
public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);

/// <summary>
/// Asynchronous counterpart of <c>Insert(TEntity)</c>, backed by EF Core <c>AddAsync</c>.
/// </summary>
/// <remarks>
/// Nothing is written to the database here; as with the synchronous overload the row is
/// inserted when the context's changes are saved.
/// </remarks>
/// <param name="entity">The entity to insert.</param>
/// <param name="cancellationToken">Token used to cancel the operation.</param>
/// <returns>The change-tracker entry for the added entity.</returns>
public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
{
    return _dbSet.AddAsync(entity, cancellationToken);
}

/// <summary>
/// Asynchronously starts tracking every entity in <paramref name="entities"/> in the Added state.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <returns>A task that completes once the entities are tracked.</returns>
public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);

/// <summary>
/// Asynchronously starts tracking every entity in <paramref name="entities"/> in the Added state.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <param name="cancellationToken">Token used to cancel the operation.</param>
/// <returns>A task that completes once the entities are tracked.</returns>
public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);

#endregion

#region Update

/// <summary>
/// Marks <paramref name="entity"/> as Modified so it is updated when changes are saved.
/// </summary>
/// <param name="entity">The entity to update.</param>
public virtual void Update(TEntity entity) => _dbSet.Update(entity);

/// <summary>
/// Marks <paramref name="entity"/> as Modified; identical to <c>Update(TEntity)</c>.
/// </summary>
/// <remarks>
/// Despite the name, this method is synchronous and returns nothing, so there is nothing to
/// await. The signature is kept as-is so existing callers continue to compile.
/// </remarks>
/// <param name="entity">The entity to update.</param>
public virtual void UpdateAsync(TEntity entity) => _dbSet.Update(entity);

/// <summary>
/// Marks every entity in <paramref name="entities"/> as Modified.
/// </summary>
/// <param name="entities">The entities to update.</param>
public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);

/// <summary>
/// Marks every entity in <paramref name="entities"/> as Modified.
/// </summary>
/// <param name="entities">The entities to update.</param>
public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);

#endregion

#region Delete

/// <summary>
/// Marks <paramref name="entity"/> as Deleted so it is removed when changes are saved.
/// </summary>
/// <param name="entity">The entity to delete.</param>
public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);

/// <summary>
/// Looks the entity up by key with <c>Find</c> and marks it as Deleted; does nothing when no
/// entity is found.
/// </summary>
/// <param name="id">The key value, passed to <c>Find</c> as a single key value.</param>
public virtual void Delete(object id)
{
    var entity = _dbSet.Find(id);
    if (entity != null)
    {
        Delete(entity);
    }
}

/// <summary>
/// Marks every entity in <paramref name="entities"/> as Deleted.
/// </summary>
/// <param name="entities">The entities to delete.</param>
public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);

/// <summary>
/// Marks every entity in <paramref name="entities"/> as Deleted.
/// </summary>
/// <param name="entities">The entities to delete.</param>
public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);

#endregion
}
}
在MS.UnitOfWork项目中添加UnitOfWork文件夹,在该文件夹下添加IUnitOfWork.cs和UnitOfWork.cs两个类。
using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Unit-of-work contract over a single <typeparamref name="TContext"/>: repositories are obtained
    /// from it, and the changes made through them are committed through it.
    /// </summary>
    /// <typeparam name="TContext">The EF Core context type being wrapped.</typeparam>
    public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext
    {
        /// <summary>
        /// The context this unit of work operates on.
        /// </summary>
        TContext DbContext { get; }

        /// <summary>
        /// Returns the repository for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
        /// <param name="hasCustomRepository">Set to <c>true</c> when a custom repository exists for the entity type.</param>
        /// <returns>A repository for <typeparamref name="TEntity"/>.</returns>
        IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class;

        /// <summary>
        /// Saves all pending changes of the context.
        /// </summary>
        /// <returns>The number of state entries written.</returns>
        int SaveChanges();

        /// <summary>
        /// Asynchronously saves all pending changes of the context.
        /// </summary>
        /// <returns>A task producing the number of state entries written.</returns>
        Task<int> SaveChangesAsync();

        /// <summary>
        /// Opens a database transaction on the context.
        /// </summary>
        /// <returns>The transaction; commit it explicitly, and dispose it when done.</returns>
        IDbContextTransaction BeginTransaction();

        /// <summary>
        /// Executes a raw SQL command against the database.
        /// </summary>
        /// <param name="sql">The SQL text; pass untrusted values via <paramref name="parameters"/>, never by concatenation.</param>
        /// <param name="parameters">Values bound to the command.</param>
        /// <returns>The number of rows affected.</returns>
        int ExecuteSqlCommand(string sql, params object[] parameters);

        /// <summary>
        /// Builds a query for <typeparamref name="TEntity"/> from raw SQL.
        /// </summary>
        /// <typeparam name="TEntity">Entity type the SQL returns.</typeparam>
        /// <param name="sql">The SQL text; pass untrusted values via <paramref name="parameters"/>, never by concatenation.</param>
        /// <param name="parameters">Values bound to the query.</param>
        /// <returns>A composable query over the SQL result.</returns>
        IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class;
    }
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Default <see cref="IUnitOfWork{TContext}"/> implementation. It caches one repository per entity
    /// type over the injected context, and forwards commits, transactions and raw SQL to that context.
    /// </summary>
    /// <typeparam name="TContext">The EF Core context type being wrapped.</typeparam>
    public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext
    {
        protected readonly TContext _context;
        protected bool _disposed = false;

        // Repository cache keyed by typeof(IRepository<TEntity>); created lazily on first use.
        protected Dictionary<Type, object> _repositories;

        /// <summary>
        /// Wraps <paramref name="context"/>.
        /// </summary>
        /// <param name="context">The context to operate on.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
        public UnitOfWork(TContext context)
        {
            _context = context ?? throw new ArgumentNullException(nameof(context));
        }

        /// <summary>
        /// The wrapped context.
        /// </summary>
        public TContext DbContext => _context;

        /// <summary>
        /// Opens a transaction on the wrapped context's database.
        /// </summary>
        /// <returns>The transaction; it must be committed explicitly and should be disposed.</returns>
        public IDbContextTransaction BeginTransaction() => _context.Database.BeginTransaction();

        /// <summary>
        /// Returns the cached repository for <typeparamref name="TEntity"/>, creating a
        /// <c>Repository&lt;TEntity&gt;</c> over the wrapped context the first time it is requested.
        /// </summary>
        /// <typeparam name="TEntity">Entity type handled by the repository.</typeparam>
        /// <param name="hasCustomRepository">
        /// Accepted for interface compatibility but not read by this implementation; the default
        /// repository is always returned.
        /// </param>
        /// <returns>The repository for <typeparamref name="TEntity"/>.</returns>
        public IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class
        {
            if (_repositories == null)
            {
                _repositories = new Dictionary<Type, object>();
            }

            var key = typeof(IRepository<TEntity>);
            if (_repositories.TryGetValue(key, out object cached))
            {
                return (IRepository<TEntity>)cached;
            }

            IRepository<TEntity> created = new Repository<TEntity>(_context);
            _repositories.Add(key, created);
            return created;
        }

        /// <summary>
        /// Executes a raw SQL command via <c>ExecuteSqlRaw</c>.
        /// </summary>
        /// <param name="sql">The SQL text; pass untrusted values via <paramref name="parameters"/>, never by concatenation.</param>
        /// <param name="parameters">Values bound to the command.</param>
        /// <returns>The number of rows affected.</returns>
        public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters);

        /// <summary>
        /// Builds a query for <typeparamref name="TEntity"/> from raw SQL via <c>FromSqlRaw</c>.
        /// </summary>
        /// <typeparam name="TEntity">Entity type the SQL returns.</typeparam>
        /// <param name="sql">The SQL text; pass untrusted values via <paramref name="parameters"/>, never by concatenation.</param>
        /// <param name="parameters">Values bound to the query.</param>
        /// <returns>A composable query over the SQL result.</returns>
        public IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class => _context.Set<TEntity>().FromSqlRaw(sql, parameters);

        /// <summary>
        /// Saves all pending changes of the wrapped context.
        /// </summary>
        /// <returns>The number of state entries written.</returns>
        public int SaveChanges()
        {
            return _context.SaveChanges();
        }

        /// <summary>
        /// Asynchronously saves all pending changes of the wrapped context (not cancellable).
        /// </summary>
        /// <returns>A task producing the number of state entries written.</returns>
        public Task<int> SaveChangesAsync() => SaveChangesAsync(CancellationToken.None);

        /// <summary>
        /// Asynchronously saves all pending changes of the wrapped context.
        /// </summary>
        /// <param name="cancellationToken">Token used to cancel the save, e.g. when the HTTP request is aborted.</param>
        /// <returns>A task producing the number of state entries written.</returns>
        public async Task<int> SaveChangesAsync(CancellationToken cancellationToken)
        {
            return await _context.SaveChangesAsync(cancellationToken);
        }

        /// <summary>
        /// Releases the repository cache and disposes the wrapped context.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Disposal hook for derived classes.
        /// </summary>
        /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>.</param>
        /// <remarks>
        /// NOTE(review): the context is disposed here even when it was supplied by a DI container,
        /// which will dispose it again at scope end. Confirm that this ownership is intended.
        /// </remarks>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            if (disposing)
            {
                _repositories?.Clear();
                _context.Dispose();
            }

            _disposed = true;
        }
    }
}
在MS.UnitOfWork项目中添加UnitOfWorkServiceExtensions.cs类:
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;

namespace MS.UnitOfWork
{
    /// <summary>
    /// Extension methods that register the unit of work in an <see cref="IServiceCollection"/>.
    /// </summary>
    public static class UnitOfWorkServiceExtensions
    {
        /// <summary>
        /// Registers <typeparamref name="TContext"/> through <c>AddDbContext</c>, together with a scoped
        /// <see cref="IUnitOfWork{TContext}"/> implemented by <see cref="UnitOfWork{TContext}"/>.
        /// </summary>
        /// <typeparam name="TContext">The EF Core context type to register.</typeparam>
        /// <param name="services">The service collection to add to.</param>
        /// <param name="action">Options configuration passed to <c>AddDbContext</c> (connection string, provider, ...).</param>
        /// <returns><paramref name="services"/>, for call chaining.</returns>
        /// <remarks>
        /// The unit-of-work registration uses <c>TryAddScoped</c>, so calling this method again for the
        /// same <typeparamref name="TContext"/> does not add a second unit-of-work descriptor. Whether a
        /// second <paramref name="action"/> takes effect is decided by <c>AddDbContext</c>.
        /// </remarks>
        /// <exception cref="System.ArgumentNullException"><paramref name="services"/> is <c>null</c>.</exception>
        public static IServiceCollection AddUnitOfWorkService<TContext>(this IServiceCollection services, System.Action<DbContextOptionsBuilder> action) where TContext : DbContext
        {
            if (services == null)
            {
                throw new System.ArgumentNullException(nameof(services));
            }

            // Register the DbContext itself.
            services.AddDbContext<TContext>(action);

            // Register the unit of work (once per context type).
            services.TryAddScoped<IUnitOfWork<TContext>, UnitOfWork<TContext>>();

            return services;
        }
    }
}
这样一来,如果项目要使用该工作单元,直接在Startup的ConfigureServices中调用AddUnitOfWorkService注册即可。
项目完成后,如下图所示:
// Example: insert a Role inside an explicit transaction obtained from the unit of work.
// If the using block is left before CommitAsync runs (e.g. because an exception was thrown),
// disposing the transaction rolls it back.
using (var tran = _unitOfWork.BeginTransaction())// start a transaction
{
    Role newRow = _mapper.Map<Role>(viewModel);
    newRow.Id = _idWorker.NextId();// obtain a snowflake Id
    newRow.Creator = 1219490056771866624;// placeholder: login is not implemented yet, so the current user is unknown; a fixed value is used until that is done
    newRow.CreateTime = DateTime.Now;// NOTE(review): local time; consider DateTime.UtcNow if the column should be time-zone independent
    _unitOfWork.GetRepository<Role>().Insert(newRow);
    await _unitOfWork.SaveChangesAsync();// sends the INSERT inside the open transaction; not yet committed
    await tran.CommitAsync();// commit the transaction
}
以上展示了通过工作单元开启事务的用法:事务用using包裹,直到调用tran.CommitAsync()才真正提交;如果中途出现异常而没有执行到提交,事务会在离开using块被释放时自动回滚。
// Load the record to be modified from the database.
// FindAsync first looks for an entity with this key that the context already tracks and only
// queries the database when none is found. The original note says viewModel.CheckField already
// loaded this record, so no second query is issued. NOTE(review): that only holds if CheckField
// loaded it with change tracking enabled (GetFirstOrDefault defaults to disableTracking = true). Confirm.
var row = await _unitOfWork.GetRepository<Role>().FindAsync(viewModel.Id);
// NOTE(review): row is null when no Role has this Id; the code below assumes it exists.
// Copy the edited values onto the entity.
row.Name = viewModel.Name;
row.DisplayName = viewModel.DisplayName;
row.Remark = viewModel.Remark;
row.Modifier = 1219490056771866624;// placeholder: login is not implemented yet, so the current user is unknown; a fixed value is used until that is done
row.ModifyTime = DateTime.Now;
// Update marks the whole entity as Modified, so every column is written, not only the changed ones.
_unitOfWork.GetRepository<Role>().Update(row);
await _unitOfWork.SaveChangesAsync();// commit