Apache Spark+PyTorch 案例实战
随着数据量和复杂性的不断增加,深度学习是提供大数据预测分析解决方案的理想方法,须要增长计算处理能力和更先进的图形处理器。经过深度学习,可以利用非结构化数据(例如图像、文本和语音),应用到图像识别、自动翻译、天然语言处理等领域。图像分类:识别和分类图像,便于排序和更准确的搜索。目标检测:快速的目标检测使自动驾驶汽车和人脸识别成为现实。天然语言处理:准确理解口语,为语音到文本和智能家居提供动力。
深度学习面临的挑战:虽然大数据和人工智能提供了大量的潜力,但从大数据中提取可操做的洞察力并非一项普通的任务。隐藏在非结构化数据(图像、声音、文本等)中的大量快速增加的信息,须要先进技术的发展和跨学科团队(数据工程、数据科学和业务)的密切合做。
基于Databricks云平台可以轻松构建、训练和部署深度学习应用程序。
node
- Databricks云平台集群提供一个交互式环境,能够轻松地使用深度学习的框架,如Tensorflow、Keras、Pytorch、Mxnet、Caffe、Cntk和Theano。
- Databricks提供处理数据准备、模型训练和大规模预测的云集群平台。
- Spark分布式计算进行性能优化,能够在强大的GPU硬件上大规模运行。
- 交互式数据科学。Databricks云平台支持多种编程语言,支持实时数据集的深度学习模型训练。
目录
- Spark+PyTorch案例简介
- Spark+PyTorch案例实战
- 建立Databricks云平台集群Notebook
- 导入库
- 准备预训练模型和数据,广播ResNet50模型
- 构建函数方法
- 检查鲜花数据的文件目录
- 将图像文件名加载到Spark数据框中
- 构建自定义Pytorch的数据集类ImageDataset、 定义模型预测的函数
- 将函数包装为Pandas UDF, 经过Pandas UDF进行模型预测
- 加载保存的Parquet文件并检查预测结果
- 《Spark大数据商业实战三部曲》第二版简介
Spark+PyTorch案例简介
本文基于AWS+Databricks云平台,基于ResNet-50网络模型,使用Spark3.0.0、Pytorch1.5.1 对鲜花图像数据(郁金香、向日葵、玫瑰、蒲公英、菊花)进行分布式图像识别实战。
Spark+PyTorch案例实战示意图:
python
ResNet网络模型有不一样的网络层数,比较经常使用的是50-layer,101-layer和152-layer,都是由ResNet模块堆叠在一块儿实现的,ResNet网络结构如图所示。
ResNet-50模型的论文连接地址( https://arxiv.org/pdf/1512.03385.pdf ),论文题目:Deep Residual Learning for Image Recognition,做者:Kaiming He、Xiangyu Zhang 、Shaoqing Ren 、Jian Sun。
sql
AWS+Databricks云平台基本操做: 编程
Spark+PyTorch案例实战
使用Spark+PyTorch案例实战的步骤以下。
数组
1. 准备预训练的模型和鲜花集数据。
从torchvision.models加载预训练的ResNet-50模型。
将鲜花数据下载到databricks文件系统空间。
2. Spark加载鲜花数据,并转换为Spark数据帧。
3. 经过Pandas UDF进行模型预测。
性能优化
建立Databricks云平台集群Notebook
在Workspace栏目中单击右键,在弹出菜单栏目中选择Create,依次输入名称、开发语言(选择python语言)、云平台集群,单击Create建立Notebook。
网络
导入库
在启用CPU的Apache Spark集群上运行notebook,设置变量cuda=False。在启用GPU的Apache Spark集群上运行notebook,设置变量cuda=True; 启动 Arrow支持。Apache Arrow是一种内存中的列式数据格式,在Spark中用于高效传输JVM和Python进程之间的数据。将Spark数据帧转换为Pandas数据帧时,可使用Arrow 进行优化;导入pandas、pytorch、pyspark等库。Pytorch设置是否使用GPU。
数据结构
```python cuda = False spark.conf.set("spark.sql.execution.arrow.enabled", "true") spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2048") import os import shutil import tarfile import time import zipfile try: from urllib.request import urlretrieve except ImportError: from urllib import urlretrieve import pandas as pd import torch from torch.utils.data import Dataset from torchvision import datasets, models, transforms from torchvision.datasets.folder import default_loader # private API from pyspark.sql.functions import col, pandas_udf, PandasUDFType from pyspark.sql.types import ArrayType, FloatType use_cuda = cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") ```
准备预训练模型和数据,广播ResNet50模型
定义输入和输出目录。建议使用Databricks Runtime 5.4 ML或更高版本,将训练集数据保存到Databricks文件系统dbfs:/ml目录,该文件映射到Driver及Worker节点上的文件/dbfs/ml。dbfs:/ml是一个特殊的文件夹,为深度学习工做负载提供高性能的I/O。在Spark Driver 节点上加载ResNet50预训练模型,并广播ResNet50模型的状态。
架构
```python URL = "http://download.tensorflow.org/example_images/flower_photos.tgz" input_local_dir = "/dbfs/ml/tmp/flower/" output_file_path = "/tmp/predictions" bc_model_state = sc.broadcast(models.resnet50(pretrained=True).state_dict()) ```
广播ResNet50预训练模型参数如图所示。
app
构建函数方法
定义get_model_for_eval方法,返回一个Pytorch ResNet50预训练模型实例,其加载Spark广播变量ResNet50模型的参数。定义maybe_download_and_extract方法。从Tensorflow网站(http://download.tensorflow.org/example_images/flower_photos.tgz)下载鲜花文件并解压缩,解压的文件包括郁金香、向日葵、玫瑰、蒲公英、菊花等图像类型。
```python def get_model_for_eval(): """Gets the broadcasted model.""" model = models.resnet50(pretrained=True) model.load_state_dict(bc_model_state.value) model.eval() return model def maybe_download_and_extract(url, download_dir): filename = url.split('/')[-1] file_path = os.path.join(download_dir, filename) print(file_path) if not os.path.exists(file_path): if not os.path.exists(download_dir): os.makedirs(download_dir) file_path, _ = urlretrieve(url=url, filename=file_path) print() print("Download finished. Extracting files.") if file_path.endswith(".zip"): # Unpack the zip-file. zipfile.ZipFile(file=file_path, mode="r").extractall(download_dir) elif file_path.endswith((".tar.gz", ".tgz")): # Unpack the tar-ball. tarfile.open(name=file_path, mode="r:gz").extractall(download_dir) print("Done.") else: print("Data has apparently already been downloaded and unpacked.") maybe_download_and_extract(url=URL, download_dir=input_local_dir) ```
检查鲜花数据的文件目录
在Databricks文件系统中检查已经下载鲜花数据的文件目录。在Databricks云平台Notebook中运行上述代码,运行结果以下:
```python print(dbutils.fs.ls("dbfs:/ml/tmp/flower_photos/")) [FileInfo(path='dbfs:/ml/tmp/flower_photos/LICENSE.txt', name='LICENSE.txt', size=418049), FileInfo(path='dbfs:/ml/tmp/flower_photos/daisy/', name='daisy/', size=0), FileInfo(path='dbfs:/ml/tmp/flower_photos/dandelion/', name='dandelion/', size=0), FileInfo(path='dbfs:/ml/tmp/flower_photos/roses/', name='roses/', size=0), FileInfo(path='dbfs:/ml/tmp/flower_photos/sunflowers/', name='sunflowers/', size=0), FileInfo(path='dbfs:/ml/tmp/flower_photos/tulips/', name='tulips/', size=0)] ```
查看Databricks文件系统中菊花目录的文件信息。
```python print(dbutils.fs.ls("dbfs:/ml/tmp/flower_photos/daisy/")) [FileInfo(path='dbfs:/ml/tmp/flower_photos/daisy/100080576_f52e8ee070_n.jpg', name='100080576_f52e8ee070_n.jpg', size=26797), FileInfo(path='dbfs:/ml/tmp/flower_photos/daisy/10140303196_b88d3d6cec.jpg', name='10140303196_b88d3d6cec.jpg', size=117247), FileInfo(path='dbfs:/ml/tmp/flower_photos/daisy/10172379554_b296050f82_n.jpg', name='10172379554_b296050f82_n.jpg', size=36410), FileInfo ...... ```
也能够将鲜花文件下载到本地电脑,查看鲜花目录如图所示。
单击向日葵的文件目录,查看向日葵的图片如图所示。
获取鲜花数据集各目录中图像文件的数量。
```python local_dir = input_local_dir + 'flower_photos/' files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(local_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg'] len(files) ```
在Databricks云平台Notebook中运行上述代码,运行结果以下:
```python Out[44]: 3670 ```
将图像文件名加载到Spark数据框中
```python files_df = spark.createDataFrame( map(lambda path: (path,), files), ["path"] ).repartition(10) # number of partitions should be a small multiple of total number of nodes display(files_df.limit(10)) ```
- 第1行代码调用spark.createDataFrame方法建立数据帧。
- 第2行代码中createDataFrame方法的第一个输入参数是map函数,在map函数中遍历每个图像的文件名,调用匿名函数将每个文件名组成(path,)的格式;createDataFrame方法的第二个参数是数据帧的列名。
- 第3行代码调用Spark的repartition方法进行重分区,将图像文件名的数据分为10个分区。
在Databricks云平台Notebook中运行上述代码,展现10条记录的图像路径及文件名,运行结果如图所示:
单击图中的View文本,能够查询Databricks云平台Spark Jobs的执行状况,如图所示。
构建自定义Pytorch的数据集类ImageDataset、 定义模型预测的函数
```python class ImageDataset(Dataset): def __init__(self, paths, transform=None): self.paths = paths self.transform = transform def __len__(self): return len(self.paths) def __getitem__(self, index): image = default_loader(self.paths[index]) if self.transform is not None: image = self.transform(image) return image ```
定义模型预测的函数
def predict_batch(paths): transform = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) images = ImageDataset(paths, transform=transform) loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8) model = get_model_for_eval() model.to(device) all_predictions = [] with torch.no_grad(): for batch in loader: predictions = list(model(batch.to(device)).cpu().numpy()) for prediction in predictions: all_predictions.append(prediction) return pd.Series(all_predictions)
- 第2行代码使用PyTorch对图像数据进行数据加强,将多个数据转换步骤整合在一块儿。
- 第3行代码按给定大小进行图像尺寸变化。
- 第4行代码图像中心收缩到给定的大小。
- 第5行代码将图像数据或者数组转换为Tensor数据结构。
- 第6行代码对图像数据按通道进行标准化处理。
- 第10行代码调用PyTorch的torch.utils.data.DataLoader方法加载图像数据,每一个批次包含500张图像。
- 第11行代码获取Spark广播的ResNet50预训练模型参数实例。
- 第16行代码获取每批次图像数据的预测分类。
在本地测试函数。
```python predictions = predict_batch(pd.Series(files[:200])) ```
将函数包装为Pandas UDF, 经过Pandas UDF进行模型预测
```python predict_udf = pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)(predict_batch) predictions_df = files_df.select(col('path'), predict_udf(col('path')).alias("prediction")) predictions_df.write.mode("overwrite").parquet(output_file_path) ```
Pandas_UDF是在PySpark 2.3版本中新增的API,Spark经过Arrow传输数据,使用Pandas处理数据。Pandas_UDF使用关键字pandas_udf做为装饰器或声明一个函数进行定义, Pandas_UDF包括Scalar(标量映射)和Grouped Map(分组映射)等类型。在Databricks云平台Notebook中运行上述代码,运行结果如图所示:
加载保存的Parquet文件并检查预测结果
```python result_df = spark.read.load(output_file_path) display(result_df) ```
在Databricks云平台Notebook中运行上述代码,运行结果如图所示:
其中预测分类值是一个大小为1000的数组,根据ResNet-50模型预测1000个分类的几率。本案例鲜花的分类实际为5类:郁金香、向日葵、玫瑰、蒲公英、菊花,感兴趣的同窗能够改写ResNet-50模型代码进行优化,获得5个类别的预测值。
《Spark大数据商业实战三部曲》第二版简介
https://duanzhihua.blog.csdn.net/article/details/106294896
在大数据和AI紧密协同时代,最佳的AI系统依赖海量数据才能构建出高度复杂的模型,海量数据须要借助Al才能挖掘出终极价值。本书以数据智能为灵魂,以Spark 2.4.X版本为载体,以Spark+ AI商业案例实战和生产环境下几乎全部类型的性能调优为核心,对企业生产环境下的Spark+AI商业案例与性能调优抽丝剥茧地进行剖析。全书共分4篇,内核解密篇基于Spark源码,从一个实战案例入手,按部就班地全面解析Spark 2.4.X版本的新特性及Spark内核源码;商业案例篇选取Spark开发中最具表明性的经典学习案例,在案例中综合介绍Spark的大数据技术;性能调优篇覆盖Spark在生产环境下的全部调优技术; Spark+ AI内幕解密篇讲解深度学习动手实践,经过整合Spark、PyTorch以及TensorFlow揭秘Spark上的深度学习内幕。
本书适合全部大数据和人工智能学习者及从业人员使用。对于有丰富大数据和AI应用经验的人员,本书也能够做为大数据和AI高手修炼的参考用书。同时,本书也特别适合做为高等院校的大数据和人工智能教材。
做者简介
王家林,Apache Spark执牛耳者现工做于硅谷的AI实验室,专一于NLP框架超过20本Spark、Al、Android书籍做者Toastmasters International Division Director GRE博士入学考试连续两次满分得到者
段智华,就任于中国电信股份有限公司上海分公司,系统架构师,CSDN博客专家,专一于Spark大数据技术研发及推广,跟随Spark核心源码技术的发展,深刻研究Spark 2.1.1版本及Spark 2.4.0版本的源码优化,对Spark大数据处理、机器学习等技术有丰富的实战经验和浓厚兴趣。