“tfp.experimental.mcmc.particle_filter”的用法

发布于 2025-01-21 23:17:19 字数 7194 浏览 0 评论 0原文

我尝试使用 TensorFlow 提供的粒子滤波器,因为它能够与其他 TensorFlow 函数结合使用。但是,我在 TensorFlow 文档中找不到 tfp.experimental.mcmc.particle_filter 的示例用法,也无法在其他网站上找到太多有关其用法的信息。

以下是我的简单代码。我遵循了文档中给出的数据类型,并输入了一些虚拟数据以查看它是否有效。

tfp.experimental.mcmc.particle_filter(
    observations=[101, np.ones([101, 1]), np.shape(np.ones([101, 1]))],
    initial_state_prior=tfd.Uniform(0, 100),
    transition_fn=tfd.Uniform(0, 100),
    observation_fn=tfd.Uniform(0, 100),
    num_particles=1000,
)

但是,以下错误不断弹出。希望有人可以指出我在哪里做错了。

提前致谢。

----> 4 tfp.experimental.mcmc.particle_filter(
      5     observations=[101, np.ones([101, 1]), np.shape(np.ones([101, 1]))],
      6     initial_state_prior=tfd.Uniform(0, 100),
      7     transition_fn=tfd.Uniform(0, 100),
      8     observation_fn=tfd.Uniform(0, 100),
      9     num_particles=1000,
     10 
     11     # initial_state_proposal=None, proposal_fn=None,
     12     # resample_fn=tfp.experimental.mcmc.resample_systematic,
     13     # resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold,
     14     # unbiased_gradients=True, rejuvenation_kernel_fn=None,
     15     # num_transitions_per_observation=1, trace_fn=_default_trace_fn,
     16     # trace_criterion_fn=_always_trace, static_trace_allocation_size=None,
     17     # parallel_iterations=1, seed=None, name=None
     18 )

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:357, in particle_filter(observations, initial_state_prior, transition_fn, observation_fn, num_particles, initial_state_proposal, proposal_fn, resample_fn, resample_criterion_fn, unbiased_gradients, rejuvenation_kernel_fn, num_transitions_per_observation, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, seed, name)
    354   static_trace_allocation_size = 0
    355   trace_criterion_fn = never_trace
--> 357 initial_weighted_particles = _particle_filter_initial_weighted_particles(
    358     observations=observations,
    359     observation_fn=observation_fn,
    360     initial_state_prior=initial_state_prior,
    361     initial_state_proposal=initial_state_proposal,
    362     num_particles=num_particles,
    363     seed=init_seed)
    364 propose_and_update_log_weights_fn = (
    365     _particle_filter_propose_and_update_log_weights_fn(
    366         observations=observations,
   (...)
    369         observation_fn=observation_fn,
    370         num_transitions_per_observation=num_transitions_per_observation))
    372 kernel = smc_kernel.SequentialMonteCarlo(
    373     propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
    374     resample_fn=resample_fn,
    375     resample_criterion_fn=resample_criterion_fn,
    376     unbiased_gradients=unbiased_gradients)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:431, in _particle_filter_initial_weighted_particles(observations, observation_fn, initial_state_prior, initial_state_proposal, num_particles, seed)
    426 initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)
    428 # Return particles weighted by the initial observation.
    429 return smc_kernel.WeightedParticles(
    430     particles=initial_state,
--> 431     log_weights=initial_log_weights + _compute_observation_log_weights(
    432         step=0,
    433         particles=initial_state,
    434         observations=observations,
    435         observation_fn=observation_fn))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:516, in _compute_observation_log_weights(step, particles, observations, observation_fn, num_transitions_per_observation)
    510 step_has_observation = (
    511     # The second of these conditions subsumes the first, but both are
    512     # useful because the first can often be evaluated statically.
    513     ps.equal(num_transitions_per_observation, 1) |
    514     ps.equal(step % num_transitions_per_observation, 0))
    515 observation_idx = step // num_transitions_per_observation
--> 516 observation = tf.nest.map_structure(
    517     lambda x, step=step: tf.gather(x, observation_idx), observations)
    519 log_weights = observation_fn(step, particles).log_prob(observation)
    520 return tf.where(step_has_observation,
    521                 log_weights,
    522                 tf.zeros_like(log_weights))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/nest.py:914, in map_structure(func, *structure, **kwargs)
    910 flat_structure = (flatten(s, expand_composites) for s in structure)
    911 entries = zip(*flat_structure)
    913 return pack_sequence_as(
--> 914     structure[0], [func(*x) for x in entries],
    915     expand_composites=expand_composites)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/nest.py:914, in <listcomp>(.0)
    910 flat_structure = (flatten(s, expand_composites) for s in structure)
    911 entries = zip(*flat_structure)
    913 return pack_sequence_as(
--> 914     structure[0], [func(*x) for x in entries],
    915     expand_composites=expand_composites)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:517, in _compute_observation_log_weights.<locals>.<lambda>(x, step)
    510 step_has_observation = (
    511     # The second of these conditions subsumes the first, but both are
    512     # useful because the first can often be evaluated statically.
    513     ps.equal(num_transitions_per_observation, 1) |
    514     ps.equal(step % num_transitions_per_observation, 0))
    515 observation_idx = step // num_transitions_per_observation
    516 observation = tf.nest.map_structure(
--> 517     lambda x, step=step: tf.gather(x, observation_idx), observations)
    519 log_weights = observation_fn(step, particles).log_prob(observation)
    520 return tf.where(step_has_observation,
    521                 log_weights,
    522                 tf.zeros_like(log_weights))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:7186, in raise_from_not_ok_status(e, name)
   7184 def raise_from_not_ok_status(e, name):
   7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None

InvalidArgumentError: params must be at least 1 dimensional [Op:GatherV2]

I tried to use the particle filter provided by TensorFlow because it can be combined with other TensorFlow functions. However, I couldn't find an example of how to use tfp.experimental.mcmc.particle_filter in the TensorFlow documentation, nor could I find much information about its usage on other websites.

The following is my simple code. I followed the data types given in the documentation and fed in some dummy data to see whether it works.

tfp.experimental.mcmc.particle_filter(
    observations=[101, np.ones([101, 1]), np.shape(np.ones([101, 1]))],
    initial_state_prior=tfd.Uniform(0, 100),
    transition_fn=tfd.Uniform(0, 100),
    observation_fn=tfd.Uniform(0, 100),
    num_particles=1000,
)

However, the following error keeps popping up. I hope someone can point out what I did wrong.

Thanks in advance.

----> 4 tfp.experimental.mcmc.particle_filter(
      5     observations=[101, np.ones([101, 1]), np.shape(np.ones([101, 1]))],
      6     initial_state_prior=tfd.Uniform(0, 100),
      7     transition_fn=tfd.Uniform(0, 100),
      8     observation_fn=tfd.Uniform(0, 100),
      9     num_particles=1000,
     10 
     11     # initial_state_proposal=None, proposal_fn=None,
     12     # resample_fn=tfp.experimental.mcmc.resample_systematic,
     13     # resample_criterion_fn=tfp.experimental.mcmc.ess_below_threshold,
     14     # unbiased_gradients=True, rejuvenation_kernel_fn=None,
     15     # num_transitions_per_observation=1, trace_fn=_default_trace_fn,
     16     # trace_criterion_fn=_always_trace, static_trace_allocation_size=None,
     17     # parallel_iterations=1, seed=None, name=None
     18 )

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:357, in particle_filter(observations, initial_state_prior, transition_fn, observation_fn, num_particles, initial_state_proposal, proposal_fn, resample_fn, resample_criterion_fn, unbiased_gradients, rejuvenation_kernel_fn, num_transitions_per_observation, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, seed, name)
    354   static_trace_allocation_size = 0
    355   trace_criterion_fn = never_trace
--> 357 initial_weighted_particles = _particle_filter_initial_weighted_particles(
    358     observations=observations,
    359     observation_fn=observation_fn,
    360     initial_state_prior=initial_state_prior,
    361     initial_state_proposal=initial_state_proposal,
    362     num_particles=num_particles,
    363     seed=init_seed)
    364 propose_and_update_log_weights_fn = (
    365     _particle_filter_propose_and_update_log_weights_fn(
    366         observations=observations,
   (...)
    369         observation_fn=observation_fn,
    370         num_transitions_per_observation=num_transitions_per_observation))
    372 kernel = smc_kernel.SequentialMonteCarlo(
    373     propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
    374     resample_fn=resample_fn,
    375     resample_criterion_fn=resample_criterion_fn,
    376     unbiased_gradients=unbiased_gradients)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:431, in _particle_filter_initial_weighted_particles(observations, observation_fn, initial_state_prior, initial_state_proposal, num_particles, seed)
    426 initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)
    428 # Return particles weighted by the initial observation.
    429 return smc_kernel.WeightedParticles(
    430     particles=initial_state,
--> 431     log_weights=initial_log_weights + _compute_observation_log_weights(
    432         step=0,
    433         particles=initial_state,
    434         observations=observations,
    435         observation_fn=observation_fn))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:516, in _compute_observation_log_weights(step, particles, observations, observation_fn, num_transitions_per_observation)
    510 step_has_observation = (
    511     # The second of these conditions subsumes the first, but both are
    512     # useful because the first can often be evaluated statically.
    513     ps.equal(num_transitions_per_observation, 1) |
    514     ps.equal(step % num_transitions_per_observation, 0))
    515 observation_idx = step // num_transitions_per_observation
--> 516 observation = tf.nest.map_structure(
    517     lambda x, step=step: tf.gather(x, observation_idx), observations)
    519 log_weights = observation_fn(step, particles).log_prob(observation)
    520 return tf.where(step_has_observation,
    521                 log_weights,
    522                 tf.zeros_like(log_weights))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/nest.py:914, in map_structure(func, *structure, **kwargs)
    910 flat_structure = (flatten(s, expand_composites) for s in structure)
    911 entries = zip(*flat_structure)
    913 return pack_sequence_as(
--> 914     structure[0], [func(*x) for x in entries],
    915     expand_composites=expand_composites)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/nest.py:914, in <listcomp>(.0)
    910 flat_structure = (flatten(s, expand_composites) for s in structure)
    911 entries = zip(*flat_structure)
    913 return pack_sequence_as(
--> 914     structure[0], [func(*x) for x in entries],
    915     expand_composites=expand_composites)

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow_probability/python/experimental/mcmc/particle_filter.py:517, in _compute_observation_log_weights.<locals>.<lambda>(x, step)
    510 step_has_observation = (
    511     # The second of these conditions subsumes the first, but both are
    512     # useful because the first can often be evaluated statically.
    513     ps.equal(num_transitions_per_observation, 1) |
    514     ps.equal(step % num_transitions_per_observation, 0))
    515 observation_idx = step // num_transitions_per_observation
    516 observation = tf.nest.map_structure(
--> 517     lambda x, step=step: tf.gather(x, observation_idx), observations)
    519 log_weights = observation_fn(step, particles).log_prob(observation)
    520 return tf.where(step_has_observation,
    521                 log_weights,
    522                 tf.zeros_like(log_weights))

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/.conda/envs/tensorflow/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:7186, in raise_from_not_ok_status(e, name)
   7184 def raise_from_not_ok_status(e, name):
   7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None

InvalidArgumentError: params must be at least 1 dimensional [Op:GatherV2]

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文