Python / Keras - acessando ModelCheckpoint callback

votos
4

Estou usando Keras de prever uma série de tempo. Como padrão Eu estou usando 20 épocas. Quero saber o que fez minha rede neural prever para cada uma das 20 épocas.

Usando model.predict fico com a última previsão. No entanto eu quero todas as previsões, ou pelo menos os últimos 10 queridos (que têm níveis de erros aceitáveis).

Para acessar que eu estou tentando a função ModelCheckpoint de Keras, no entanto estou tendo problemas para acessá-lo depois. Eu estou usando o seguinte código:

model=Sequential()

model.add(GRU(input_dim=col,init='uniform',output_dim=20))
model.add(Dense(10))
model.add(Dense(5))
model.add(Activation(softmax))
model.add(Dense(1))

model.compile(loss=mae, optimizer=RMSprop)

checkpoint=ModelCheckpoint(filepath='/Users/Alex/checkpoint.hdf5')

model.fit(X=predictor_train, y=target_train, nb_epoch=20, batch_size=batch,validation_split=0.1) #best validation split at 0.1
model.evaluate(X=predictor_train, y=target_train,batch_size=batch,show_accuracy=True)

print checkpoint

Objetivamente, minhas perguntas são:

  • Eu esperava que depois de executar o código que eu iria encontrar um arquivo chamado checkpoint.hdf5 dentro da pasta / Users / Alex, porém eu não sabia. o que estou perdendo?

  • Quando eu imprimir checkpointo que eu vejo é um keras.callbacks.ModelCheckpoint object at 0x117471290. Existe uma maneira para imprimir o que eu quero? Como o código ficaria assim?

Sua ajuda é muito apreciado :)

Publicado 26/04/2016 em 20:15
fonte usuário
Em outras línguas...                            


1 respostas

votos
8

Há dois problemas neste código:

  • Você não está passando o retorno de chamada para método de ajuste do modelo. Isso é feito com o argumento de palavra-chave "callbacks".
  • O caminho de arquivo deve conter espaços reservados (como "{marcaram época: 02d} - {val_loss: .2f}" que são usados ​​com str.format por Keras a fim de salvar cada época para um arquivo diferente.

Assim, a versão correta deve ser algo como:

checkpoint = ModelCheckpoint(filepath='/Users/Alex/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5')

model.fit(X=predictor_train, y=target_train, nb_epoch=20,
         batch_size=batch,validation_split=0.1, callbacks=[checkpoint])

Você também pode adicionar outros tipos de chamadas de retorno na lista que é atribuído a essa palavra-chave.

Infelizmente o objeto de retorno de chamada não armazena as informações do histórico de modo que não podem ser recuperados a partir dele.

Respondeu 26/04/2016 em 21:57
fonte usuário

Cookies help us deliver our services. By using our services, you agree to our use of cookies. Learn more