Tensorflow ne peut pas obtenir `image.shape` à partir de la méthode dans` dataset.map (mapFn) `

10

J'essaie de faire l' tensorflowéquivalent de torch.transforms.Resize(TRAIN_IMAGE_SIZE), ce qui redimensionne la plus petite dimension d'image TRAIN_IMAGE_SIZE. Quelque chose comme ça

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

La réponse simple est ici: Tensorflow : Recadrer la plus grande région carrée centrale de l'image

Mais quand j'utilise la méthode avec tf.data.Dataset.map(transforms), je reçois shape=(None,None,3)de l'intérieur largest_sq_crop(image). La méthode fonctionne bien quand je l'appelle normalement.

Michael
la source
1
Je pense que le problème vient du fait qu'ils EagerTensorsne sont pas disponibles à l'intérieur Dataset.map(), la forme est donc inconnue. Y at-il un travail autour?
michael
Pouvez-vous inclure la définition de largest_sq_crop?
jakub

Réponses:

1

J'ai trouvé la réponse. Cela avait à voir avec le fait que ma méthode de redimensionnement fonctionnait bien avec une exécution soignée, par exemple, tf.executing_eagerly()==Truemais échouait lorsqu'elle était utilisée à l'intérieur dataset.map(). Apparemment, dans cet environnement d'exécution, tf.executing_eagerly()==False.

Mon erreur était dans la façon dont je déballais la forme de l'image pour obtenir les dimensions pour la mise à l'échelle. L'exécution du graphique Tensorflow ne semble pas prendre en charge l'accès au tensor.shapetuple.

  # wrong
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # also wrong
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # but this works!!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

J'utilisais des dimensions de forme en aval dans ma dataset.map()fonction et cela a levé l'exception suivante car elle obtenait Noneau lieu d'une valeur.

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

Lorsque je suis passé au déballage manuel de la forme tf.shape(), tout a bien fonctionné.

Michael
la source