使用emnist数据集进行简单的FedAvg算法

import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff# 测试tff是否安装成功
# print(tff.federated_computation(lambda: 'Hello World')())
# 加载数据集
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(cache_dir='/home/cqx/PycharmProjects/cache/fed_emnist_digitsonly')
# 查看数据集长度和结构
print(len(emnist_train.client_ids))
print(emnist_train.element_type_structure)
# 给指定客户端创造数据集 返回值tf.data.Dataset` object.
example_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0])
# iter迭代,Iterator对象可以被next()函数调用并不断返回下一个数据,直到没有数据时抛出StopIteration错误。
example_element = next(iter(example_dataset))
print(example_element['label'].numpy())# 使用数据集转换完成预处理。
# 在这里,我们将图像拉平到数组中,将各个示例打乱,并将它们组织成批次,然后重命名特征
# 客户端数目
NUM_CLIENTS = 10
# 训练次数
NUM_EPOCHS = 5
# 批次大小
BATCH_SIZE = 20
# 随机打乱
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10def preprocess(dataset):def batch_format_fn(element):"""Flatten a batch `pixels` and return the features as an `OrderedDict`."""return collections.OrderedDict(x=tf.reshape(element['pixels'], [-1, 784]),y=tf.reshape(element['label'], [-1, 1]))# repeat(count) 将数据重复count次# shuffle(shuffleSize,seed)# dataset.examples.batch(20).prefetch(2) 预取(2批,每批20个例子)return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)preprocessed_example_dataset = preprocess(example_dataset)# a = [24, 76, "ab"]
# tf.nest.map_structure(lambda p: p * 2, a)
# [48, 152, 'abab']
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),next(iter(preprocessed_example_dataset)))
print(len(sample_batch['y']))
print(sample_batch)# 从给定的一组用户作为一轮培训或评估的输入.
def make_federated_data(client_data, client_ids):return [preprocess(client_data.create_tf_dataset_for_client(x))for x in client_ids]# 构造客户端数据
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]
federated_train_data = make_federated_data(emnist_train, sample_clients)
print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')# 建立网络模型
def create_keras_model():return tf.keras.models.Sequential([tf.keras.layers.InputLayer(input_shape=(784,)),tf.keras.layers.Dense(10, kernel_initializer='zeros'),tf.keras.layers.Softmax(),])# 为了将任何模型与 TFF 一起使用,需要将其包装在 tff.learning.Model 接口的实例中
# 将模型和示例数据批处理作为参数
def model_fn():keras_model = create_keras_model()return tff.learning.from_keras_model(keras_model,input_spec=preprocessed_example_dataset.element_spec,loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])# 联邦平均算法实现
# client优化器仅用于计算每个客户端上的本地模型更新。
# server优化器将平均更新应用于全局模型更新
training_process = tff.learning.algorithms.build_weighted_fed_avg(model_fn,client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
# 输出服务器上的FedAVG进程。
print('可视化服务器上的FedAVG进程')
print(training_process.initialize.type_signature.formatted_representation())
# 初始化服务器状态
train_state = training_process.initialize()
# next
# 发送服务器状态 (包括模型参数)给客户,
# 在他们的设备上进行训练本地数据,收集和平均模型更新,并生成新的更新
# 发送给服务器,更新全局模型# 进行一次训练
# result = training_process.next(train_state, federated_train_data)
# train_state = result.state
# train_metrics = result.metrics
# print('round  1, metrics={}'.format(train_metrics))# 训练多轮
NUM_ROUNDS = 11
for round_num in range(1, NUM_ROUNDS):result = training_process.next(train_state, federated_train_data)train_state = result.statetrain_metrics = result.metricsprint('round {:2d}, metrics={}'.format(round_num, train_metrics))

结果展示

3383
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
1
20
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]]))])
Number of client datasets: 10
First dataset: <PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>
可视化服务器上的FedAVG进程
( -> <
  global_model_weights=<
    trainable=<
      float32[784,10],
      float32[10]
    >,
    non_trainable=<>
  >,
  distributor=<>,
  client_work=<>,
  aggregator=<
    value_sum_process=<>,
    weight_sum_process=<>
  >,
  finalizer=<
    int64
  >
>@SERVER)
round  1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834726), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.8616652), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.5297604), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053502), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.3153887), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240258), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164267), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])

Process finished with exit code 0

联邦学习实战2(基于TFF)相关推荐

  1. 【联邦学习实战】基于同态加密和差分隐私混合加密机制的FedAvg

    联邦学习实战--基于同态加密和差分隐私混合加密机制的FedAvg 前言 1. FedAvg 1.1 getData.py 1.2 Models.py 1.3 client.py 1.4 server. ...

  2. 【阅读笔记】联邦学习实战——联邦学习平台介绍

    前言 FATE是微众银行开发的联邦学习平台,是全球首个工业级的联邦学习开源框架,在github上拥有近4000stars,可谓是相当有名气的,该平台为联邦学习提供了完整的生态和社区支持,为联邦学习初学 ...

  3. 【赠书】重磅好书联邦学习实战来袭!你值得拥有一本

    我们以前给大家介绍过杨强教授团队所著的业界首本联邦学习的书籍,现在这本书的实战版来了,5月刚刚出版,本次给大家赠送3本新书,即<联邦学习实战>. 这是一本什么样的书 所谓联邦学习技术,是一 ...

  4. 【阅读笔记】联邦学习实战——联邦个性化推荐案例

    联邦学习实战--联邦个性化推荐案例 前言 1. 引言 2. 传统的集中式个性化推荐 2.1 矩阵分解 2.2 因子分解机 3. 联邦矩阵分解 3.1 算法详解 3.2 详细实现 4 联邦因子分解机 4 ...

  5. 【阅读笔记】联邦学习实战——联邦学习智能用工案例

    联邦学习实战--联邦学习智能用工案例 前言 1. 智能用工简介 2. 智能用工平台 2.1 智能用工的架构设计 2.2 智能用工的算法设计 3. 利用横向联邦提升智能用工模型 4. 设计联邦激励机制 ...

  6. 《联邦学习实战》杨强 读书笔记十七——联邦学习加速方法

    目录 同步参数更新的加速方法 增加通信间隔 减少传输内容 非对称的推送和获取 计算和传输重叠 异步参数更新的加速方法 基于模型集成的加速方法 One-Shot联邦学习 基于有监督的集成学习方法 基于半 ...

  7. 【阅读笔记】联邦学习实战——联邦学习攻防实战

    联邦学习实战--联邦学习攻防实战 前言 1. 后门攻击 1.1 问题定义 1.2 后门攻击策略 1.3 详细实现 2. 差分隐私 2.1 集中式差分隐私 2.2 联邦差分隐私 2.3 详细实现 3. ...

  8. 【阅读笔记】联邦学习实战——构建公平的大数据交易市场

    联邦学习实战--构建公平的大数据交易市场 前言 1. 大数据交易 1.1 数据交易定义 1.2 数据确权 1.3 数据定价 2. 基于联邦学习构建新一代大数据交易市场 3. 联邦学习激励机制助力数据交 ...

  9. 《联邦学习实战》杨强 读书笔记十四——构建公平的大数据交易市场

    当数据具有资产属性之后,数据便可以直接或者间接地为公司.为社会创造价值和收益,并且可以作为一种特殊的商品在市场中进行交易. 与传统的商品交易相比,数据资产交易的市场前景更广阔,但同时也面临着很多的挑战 ...

  10. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

最新文章

  1. Excel制作带勾的方框
  2. .NET技术学习目录整理
  3. [转载][总结]函数getopt(),getopt_long及其参数optind
  4. 运行中的Nginx进程间的关系
  5. Linux 命令之 ps -- 显示进程状态/查看进程信息
  6. [物理学与PDEs]第2章第4节 激波 4.2 熵条件
  7. lwip协议栈实现服务器端主动发送,《LwIP协议栈源码详解——TCP/IP协议的实现》IP层输入...
  8. 智头条:3月智能圈投融资大事记:极米、涂鸦上市,大华获中国移动56亿投资,凯迪仕获近1亿美元融资,小米投100亿美金造车
  9. 能删除Windows下“本地安装源 (Msocache)”吗?
  10. 怎样调整计算机视角,电脑调节不了CAD极轴角度怎样解决|电脑中调节CAD极轴角度的方法...
  11. 说一说我在创建星球这10多天,在星球里干了啥?
  12. ubuntu16.04 创建用户,赋予权限
  13. CouchDB安装与使用
  14. 最长公共子序列 LCS(模板) poj 1458
  15. 医疗康复机器人研究进展及趋势
  16. HPE总裁兼CEO接受《财富》杂志专访
  17. 有一种异性朋友叫温暖
  18. 2018年1月27日训练日记
  19. 运维利器之mysql进行表的分区
  20. JavaScript的数学计算库:decimal.js

热门文章

  1. html自动滚动列表,HTML滚动显示信息列表代码
  2. 南油 机器人_机器人创客-作文
  3. 笔记本电脑性价比排行2019_笔记本电脑性价比排行榜2020前十名
  4. 结对项目—第一次作业(俄罗斯方块)
  5. python和接码平台对接_python验证码识别接口及识别思路代码
  6. LOWORD, HIWORD, LOBYTE, HIBYTE
  7. dpdk-16.04 igb_uio 模块分析
  8. 以太网速率怎么手动设置_如何提高上网速度_直接设置网络接口跃点数就可以了 - 驱动管家...
  9. 方德操作系统自带jdk如何切换
  10. 自考管理系统中计算机应用可以不考吗,中南财大自考本科中科目“管理系统中计算机应用(实践)”能不能不考?...