No método Keras model.fit_generator (), que é a fila gerador de parâmetro controlado "max_q_size" utilizado?

votos
16

Eu construí um gerador simples que os rendimentos de tuple(inputs, targets)apenas com item único nos inputse targetslistas - basicamente rastejando o conjunto de dados, item de uma amostra de cada vez.

Eu passar este gerador em:

  model.fit_generator(my_generator(),
                      nb_epoch=10,
                      samples_per_epoch=1,
                      max_q_size=1  # defaults to 10
                      )

Entendi:

  • nb_epoch é o número de vezes que o lote formação será executado
  • samples_per_epoch é o número de amostras formados com por época

Mas o que é max_q_sizee por que seria o padrão para 10? Eu pensei que o propósito de usar um gerador foi para os dados do lote define em pedaços razoáveis, então por que a fila adicional?

Publicado 02/05/2016 em 15:07
fonte usuário
Em outras línguas...                            


1 respostas

votos
26

Isso simplesmente define o tamanho máximo da fila de treinamento interno que é usado para "precache" suas amostras do gerador. É usada durante a geração das as filas

def generator_queue(generator, max_q_size=10,
                    wait_time=0.05, nb_worker=1):
    '''Builds a threading queue out of a data generator.
    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
    '''
    q = queue.Queue()
    _stop = threading.Event()

    def data_generator_task():
        while not _stop.is_set():
            try:
                if q.qsize() < max_q_size:
                    try:
                        generator_output = next(generator)
                    except ValueError:
                        continue
                    q.put(generator_output)
                else:
                    time.sleep(wait_time)
            except Exception:
                _stop.set()
                raise

    generator_threads = [threading.Thread(target=data_generator_task)
                         for _ in range(nb_worker)]

    for thread in generator_threads:
        thread.daemon = True
        thread.start()

    return q, _stop

Em outras palavras, você tem uma linha de enchimento da fila até dado, capacidade máxima diretamente de seu gerador, enquanto (por exemplo) rotina de treinamento consome seus elementos (e às vezes espera para a conclusão)

 while samples_seen < samples_per_epoch:
     generator_output = None
     while not _stop.is_set():
         if not data_gen_queue.empty():
             generator_output = data_gen_queue.get()
             break
         else:
             time.sleep(wait_time)

e por padrão de 10? Nenhuma razão particular, como a maioria dos padrões - ele simplesmente faz sentido, mas você poderia usar valores diferentes também.

Construção como este sugere, que os autores pensou em geradores de dados caros, o que pode levar tempo para execture. Por exemplo, considere o download de dados através de uma rede na chamada gerador - então faz sentido para precache alguns próximos lotes, e descarregue próximos queridos em paralelo por uma questão de eficiência e ser robusto a erros de rede etc.

Respondeu 02/05/2016 em 18:09
fonte usuário

Cookies help us deliver our services. By using our services, you agree to our use of cookies. Learn more