Почему tf.get_variable ('test') возвращает переменную с именем test_1?

1

Я создал переменную tf.Variable с tf.Variable. Интересно, почему, если я вызываю tf.get_variable с тем же именем, исключение не создается и создается новая переменная с добавленным именем?

import tensorflow as tf

class QuestionTest(tf.test.TestCase):

    def test_version(self):
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable(self):
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")

        b = tf.get_variable('test', shape=(), trainable=False)
        self.assertEqual(b.name, "test_1:0")

        self.assertNotEqual(a, b, msg=''a' is not 'b'')

        with self.assertRaises(ValueError) as ecm:
            tf.get_variable('test', shape=(), trainable=False)
        exception = ecm.exception
        self.assertStartsWith(str(exception), "Variable test already exists, disallowed.")
Теги:
tensorflow

1 ответ

1
Лучший ответ

Это связано с тем, что tf.Variable - это метод низкого уровня, в котором хранится созданная переменная в коллекции GLOBALS (или LOCALS), в то время как tf.get_variable сохраняет учетную запись переменной, которую она создала, сохраняя их в хранилище переменных.

Когда вы сначала вызываете tf.Variable, созданная переменная не добавляется в хранилище переменных, позволяя думать, что никакой переменной с именем "test" не было создано.

Итак, когда вы позже tf.get_variable("test") он будет смотреть на хранилище переменных, см., Что в нем нет переменной с именем "test".
Таким образом, он вызовет tf.Variable, который создаст переменную с добавленным именем "test_1" хранящимся в хранилище переменных под ключом "test".

import tensorflow as tf

class AnswerTest(tf.test.TestCase):

    def test_version(self):
        self.assertEqual(tf.__version__, '1.10.1')    

    def test_variable_answer(self):
        """Using the default variable scope"""
        # Let first check the __variable_store and the GLOBALS collections.
        self.assertListEqual(tf.get_collection(("__variable_store",)), [], 
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [],
                             "No global variables")

        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [a],
                             "but 'a' is in global variables.")

        b = tf.get_variable('test', shape=(), trainable=False)
        self.assertNotEqual(a, b, msg=''a' is not 'b'')
        self.assertEqual(b.name, "test_1:0", msg="'b' name is not 'test'.")
        self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0,
                        "There is now a variable store.")
        var_store = tf.get_collection(("__variable_store",))[0]
        self.assertDictEqual(var_store._vars, {"test": b},
                             "and variable 'b' is in it.")
        self.assertListEqual(tf.global_variables(), [a, b],
                             "while 'a' and 'b' are in global variables.")

        with self.assertRaises(ValueError) as exception_context_manager:
            tf.get_variable('test', shape=(), trainable=False)
        exception = exception_context_manager.exception
        self.assertStartsWith(str(exception),
                              "Variable test already exists, disallowed.")

То же самое верно при использовании явной переменной.

    def test_variable_answer_with_variable_scope(self):
        """Using now a variable scope"""
        self.assertListEqual(tf.get_collection(("__variable_store",)), [], 
                             "No variable store.")

        with tf.variable_scope("my_scope") as scope:
            self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0, 
                            "There is now a variable store.")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(var_store._vars, {},
                                 "but with variable in it.")

            a = tf.Variable(0., trainable=False, name='test')
            self.assertEqual(a.name, "my_scope/test:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(var_store._vars, {},
                                 "Still no variable in the store.")


            b = tf.get_variable('test', shape=(), trainable=False)
            self.assertEqual(b.name, "my_scope/test_1:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(
                var_store._vars, {"my_scope/test": b},
                "'b' is in the store, but notice the difference between its name and its key in the store.")

            with self.assertRaises(ValueError) as exception_context_manager:
                tf.get_variable('test', shape=(), trainable=False)
            exception = exception_context_manager.exception
            self.assertStartsWith(str(exception),
                                  "Variable my_scope/test already exists, disallowed.")

Ещё вопросы

Сообщество Overcoder
Наверх
Меню