tf: From RNN to BERT
Published: 2019-04-30

Data initialization
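import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Embedding, SimpleRNN, GRU, LSTM, Dense

# Load MNIST: 60,000 training images of 28x28 pixels with values 0-255.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each image into a sequence of 784 pixel values.
x_train = x_train.reshape(60000, -1)
# One-hot encode the labels to match the categorical_crossentropy loss.
y_train = keras.utils.to_categorical(y_train)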
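To make the shapes concrete, here is a minimal NumPy-only sketch of what the reshape and one-hot steps do. It uses a small random batch as a stand-in for MNIST (an assumption for illustration, so nothing is downloaded), and `np.eye(10)[labels]` is used as an equivalent of `to_categorical`, not as the Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a tiny MNIST batch: 4 grayscale images, 28x28, values 0..255.
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 3])

# Flatten each image into a sequence of 784 pixel values (one "token" per pixel).
x_flat = images.reshape(len(images), -1)

# One-hot encode the labels: row i of the identity matrix encodes class i.
y_onehot = np.eye(10, dtype=np.float32)[labels]

print(x_flat.shape)    # (4, 784)
print(y_onehot.shape)  # (4, 10)
```

Because every pixel value lies in 0..255, each flattened image can be fed to an `Embedding` layer with `input_dim=256` as if it were a sequence of 784 token IDs.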
SimpleRNN
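model1 = keras.Sequential()
# Treat each pixel value (0-255) as a token ID and embed it in 5 dimensions.
model1.add(Embedding(input_dim=256, output_dim=5))
model1.add(SimpleRNN(units=2))
# Softmax turns the 10 outputs into class probabilities, as the loss expects.
model1.add(Dense(10, activation="softmax"))
model1.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model1.fit(x_train, y_train, batch_size=10)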
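As a rough illustration of what the embedding plus `SimpleRNN(units=2)` stack computes, the sketch below runs the standard recurrence h_t = tanh(x_t W_x + h_{t-1} W_h + b) in NumPy. The function name `simple_rnn_forward` and the random weights are hypothetical; this is not the Keras implementation and nothing here is trained.

```python
import numpy as np

def simple_rnn_forward(tokens, embedding, w_x, w_h, b):
    """Return the final hidden state for each token sequence in the batch."""
    h = np.zeros((tokens.shape[0], w_h.shape[0]))
    for t in range(tokens.shape[1]):
        x_t = embedding[tokens[:, t]]         # look up embeddings: (batch, embed_dim)
        h = np.tanh(x_t @ w_x + h @ w_h + b)  # one recurrent step: (batch, units)
    return h

rng = np.random.default_rng(0)
embed_dim, units = 5, 2  # same sizes as the model above

# Hypothetical, randomly initialised weights (illustration only).
embedding = rng.normal(scale=0.1, size=(256, embed_dim))
w_x = rng.normal(scale=0.1, size=(embed_dim, units))
w_h = rng.normal(scale=0.1, size=(units, units))
b = np.zeros(units)

tokens = rng.integers(0, 256, size=(3, 784))  # 3 fake flattened images
h_final = simple_rnn_forward(tokens, embedding, w_x, w_h, b)
print(h_final.shape)  # (3, 2)
```

Only the last hidden state is returned, so the whole 784-step sequence is summarised in two numbers before the `Dense` layer sees it.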
GRU
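model2 = keras.Sequential()
# Same pipeline as the SimpleRNN model, with a GRU as the recurrent layer.
model2.add(Embedding(input_dim=256, output_dim=5))
model2.add(GRU(units=2))
# Softmax turns the 10 outputs into class probabilities, as the loss expects.
model2.add(Dense(10, activation="softmax"))
model2.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model2.fit(x_train, y_train, batch_size=10)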
LSTM
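model3 = keras.Sequential()
# Same pipeline again, with an LSTM as the recurrent layer.
model3.add(Embedding(input_dim=256, output_dim=5))
model3.add(LSTM(units=2))
# Softmax turns the 10 outputs into class probabilities, as the loss expects.
model3.add(Dense(10, activation="softmax"))
model3.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model3.fit(x_train, y_train, batch_size=10)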
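The SimpleRNN, GRU, and LSTM models differ only in the recurrent layer, so one quick way to compare them is to count that layer's trainable parameters. The helpers below are a sketch using the standard per-gate formulas. The GRU count assumes the `reset_after=True` variant (two bias vectors per gate), which should be the TF 2 Keras default; `model.summary()` reports the actual numbers for your version.

```python
def simple_rnn_params(input_dim, units):
    # One weight set: input kernel + recurrent kernel + bias.
    return units * (input_dim + units + 1)

def lstm_params(input_dim, units):
    # Four weight sets: input, forget, and output gates plus the cell candidate.
    return 4 * units * (input_dim + units + 1)

def gru_params(input_dim, units, reset_after=True):
    # Three weight sets: update gate, reset gate, candidate state.
    # Assumption: reset_after=True keeps separate input and recurrent biases.
    biases = 2 if reset_after else 1
    return 3 * units * (input_dim + units + biases)

input_dim, units = 5, 2  # embedding size and recurrent units used in the models
print(simple_rnn_params(input_dim, units))  # 16
print(gru_params(input_dim, units))         # 54
print(lstm_params(input_dim, units))        # 64
```

With only two units every variant is tiny, so any accuracy difference between the three models comes from the gating, not from capacity.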
encoder-decoder
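model4 = keras.Sequential()
# Same layer stack as the LSTM model: the LSTM encodes the pixel sequence
# into a fixed-size state, and the Dense layer decodes it into class scores.
model4.add(Embedding(input_dim=256, output_dim=5))
model4.add(LSTM(units=2))
model4.add(Dense(10, activation="softmax"))
model4.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model4.fit(x_train, y_train, batch_size=10)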

Reposted from: http://jlugf.baihongyu.com/
