0

0

优化SciPy优化函数输入:解决矩阵维度不匹配与提升代码效率

花韻仙語

花韻仙語

发布时间:2025-07-28 20:24:12

|

946人浏览过

|

来源于php中文网

原创

优化SciPy优化函数输入:解决矩阵维度不匹配与提升代码效率

本文旨在解决使用scipy.optimize.fmin时因输入数组被展平导致的矩阵维度错误,并提供一套全面的优化方案。内容涵盖如何在目标函数内部重塑输入数组、利用NumPy向量化操作提升计算效率、推荐使用更现代且功能强大的scipy.optimize.minimize函数,并探讨针对特定问题采用线性代数解法的可能性,旨在帮助读者编写更健壮、高效的数值优化代码。

理解并解决scipy.optimize.fmin的输入维度问题

在使用scipy.optimize.fmin进行数值优化时,一个常见的陷阱是其默认行为会将传递给目标函数的初始猜测值(x0参数)展平(ravel)成一维数组,无论其原始形状如何。这导致当目标函数内部期望接收多维数组(如矩阵)时,会出现维度不匹配的错误,例如“matrices are different sizes”或“input operand 1 has a mismatch in its core dimension”。

问题根源: optimize.fmin为了通用性,将所有输入参数视为一维向量进行处理。因此,如果您传入一个 (4, 4) 的矩阵作为初始猜测 guess,在目标函数 objfunc 内部,guess 会变成一个 (16,) 的一维数组。

解决方案: 最直接的解决办法是在目标函数 objfunc 的开头,手动将展平后的 guess 数组重塑(reshape)回其预期的多维形状。

import numpy as np
from scipy import optimize
import math

# 定义矩阵维度
rows, cols = 4, 4

# 示例数据
guess = np.array([
    [1, -1, 2, 0],
    [0,  2, 0, 0],
    [1,  0, 0, 1],
    [0,  1, 2, 0]
])
inputArray = np.array([
    [2, 4, 6, 9],
    [2, 3, 1, 0],
    [7, 2, 6, 4],
    [1, 5, 2, 1]
])
goalArray = np.array([
    [14, 5, 17, 17],
    [4,  6, 2,   0],
    [3,  9, 8,  10],
    [16, 7, 13,  8]
])

# 修正后的目标函数
def objfunc(guess_flat, inputArray, goalArray):
    # 核心修正:将展平的guess_flat重塑回原始矩阵形状
    guess = guess_flat.reshape((rows, cols))

    model = guess @ inputArray # 矩阵乘法

    # 计算误差:使用NumPy向量化操作替代循环
    # 原始问题中是求差的绝对值之和,等同于L1范数
    # sum_error = np.sum(np.abs(goalArray - model))
    # 如果是欧几里得距离(L2范数),则使用:
    sum_error = np.linalg.norm(goalArray - model, ord='fro') # Frobenius范数等同于展平后的L2范数

    return sum_error

# 验证修正后的目标函数
print(f"Initial guess shape: {guess.shape}")
print(f"Input array shape: {inputArray.shape}")
print(f"Initial objective function value: {objfunc(guess.ravel(), inputArray, goalArray):.4f}")

# 调用optimize.fmin进行优化
# 注意:fmin的x0参数应为一维数组,因此需要对guess进行ravel()操作
minimum_fmin = optimize.fmin(objfunc, guess.ravel(), args=(inputArray, goalArray))

print("\n--- Results from optimize.fmin ---")
print(f"Optimized flat array: {minimum_fmin}")
print(f"Reshaped optimized matrix:\n{minimum_fmin.reshape((rows, cols))}")

优化目标函数:利用NumPy的向量化能力

原始的目标函数中使用了嵌套的 for 循环来计算误差。在处理NumPy数组时,应尽量避免显式的Python循环,因为NumPy提供了高效的向量化操作,可以显著提升计算性能。

改进前:

    sum = 0
    for i in range(rows):
        for j in range(cols):
            sum = sum + math.sqrt((goalArray[i][j] - model[i][j]) ** 2)

这段代码实际上是在计算矩阵元素差的绝对值之和(L1范数),因为 math.sqrt(x**2) 等价于 abs(x)。

改进后: 可以使用NumPy的 np.abs 和 np.sum 函数进行向量化操作,或者更推荐使用 np.linalg.norm 来清晰表达计算的是哪种范数。

  • 计算元素差的绝对值之和(L1范数):
    sum_error = np.sum(np.abs(goalArray - model))
  • 计算元素差的平方和的平方根(Frobenius范数,等同于展平后的L2范数):
    sum_error = np.sqrt(np.sum((goalArray - model) ** 2))
    # 或者更简洁、更清晰地使用 numpy.linalg.norm
    sum_error = np.linalg.norm(goalArray - model, ord='fro')

    在上述示例代码中,我们已经将目标函数 objfunc 更新为使用 np.linalg.norm,这不仅提高了效率,也使代码意图更加明确。

推荐的现代优化方法:scipy.optimize.minimize

scipy.optimize.fmin 是一个遗留函数,尽管仍可使用,但对于新代码的开发,官方推荐使用更强大、更灵活的 scipy.optimize.minimize 函数。minimize 提供了一个统一的接口来访问多种优化算法(如BFGS, Nelder-Mead, SLSQP等),并且其文档明确指出其 x0 参数必须是一维数组,这与 fmin 的内部行为保持一致。

讯飞听见会议
讯飞听见会议

科大讯飞推出的AI智能会议系统

下载

使用 optimize.minimize:

# 使用optimize.minimize进行优化
# x0 参数必须是展平的一维数组
minimum_result = optimize.minimize(objfunc, guess.ravel(), args=(inputArray, goalArray), method='BFGS')

print("\n--- Results from optimize.minimize (BFGS) ---")
print(f"Optimization successful: {minimum_result.success}")
print(f"Message: {minimum_result.message}")
print(f"Final objective function value: {minimum_result.fun:.4f}")
print(f"Optimized flat array: {minimum_result.x}")
print(f"Reshaped optimized matrix:\n{minimum_result.x.reshape((rows, cols))}")

# 尝试不同的优化方法,例如'Nelder-Mead' (fmin的默认方法)
minimum_result_nm = optimize.minimize(objfunc, guess.ravel(), args=(inputArray, goalArray), method='Nelder-Mead')
print("\n--- Results from optimize.minimize (Nelder-Mead) ---")
print(f"Optimization successful: {minimum_result_nm.success}")
print(f"Message: {minimum_result_nm.message}")
print(f"Final objective function value: {minimum_result_nm.fun:.4f}")
print(f"Reshaped optimized matrix:\n{minimum_result_nm.x.reshape((rows, cols))}")

optimize.minimize 的优势:

  • 统一接口: 通过 method 参数轻松切换不同的优化算法。
  • 更丰富的输出: 返回一个 OptimizeResult 对象,包含优化是否成功、迭代次数、梯度信息等详细结果。
  • 更清晰的文档: 对输入参数和算法行为有更明确的说明。
  • 处理非平滑目标函数: 对于像绝对值之和这类非平滑(不可导)的目标函数,Nelder-Mead 等不依赖梯度的算法可能更适用,而 BFGS 等梯度下降算法则需要目标函数是可导的。

特定问题的线性代数解决方案

值得注意的是,对于本教程中的特定问题——寻找一个转换矩阵 guess,使得 guess @ inputArray 尽可能接近 goalArray,如果 inputArray 是非奇异矩阵(可逆),这实际上是一个线性方程组问题,可以通过线性代数直接求解,而无需使用数值优化。

方程可以表示为 X * A = B,其中 X 是我们寻找的 guess 矩阵,A 是 inputArray,B 是 goalArray。 在Python/NumPy中,矩阵乘法的顺序通常是 X @ A。要解出 X,如果 A 可逆,则 X = B @ A_inv。

然而,numpy.linalg.solve(A, B) 函数用于求解 A @ X = B 形式的线性方程组。因此,我们需要对矩阵进行转置以适应 solve 函数的签名: guess @ inputArray = goalArray 两边同时右乘 inputArray 的逆矩阵: guess = goalArray @ np.linalg.inv(inputArray)

或者,使用 numpy.linalg.solve,它更数值稳定: 如果 A @ X = B,则 X = np.linalg.solve(A, B)。 我们的问题是 guess @ inputArray = goalArray。 为了匹配 A @ X = B 的形式,我们可以对整个方程进行转置: (guess @ inputArray).T = goalArray.TinputArray.T @ guess.T = goalArray.T 现在,这符合 A @ X = B 的形式,其中 A = inputArray.T,X = guess.T,B = goalArray.T。 因此,guess.T = np.linalg.solve(inputArray.T, goalArray.T)。 最后,将结果转置回来得到 guess: guess = np.linalg.solve(inputArray.T, goalArray.T).T

# 检查inputArray是否可逆
if np.linalg.det(inputArray) != 0:
    # 使用线性代数直接求解
    solution_linalg = np.linalg.solve(inputArray.T, goalArray.T).T
    print("\n--- Solution using Linear Algebra (np.linalg.solve) ---")
    print(f"Directly calculated transformation matrix:\n{solution_linalg}")
    # 验证解的准确性
    calculated_goal = solution_linalg @ inputArray
    print(f"Verifying result (solution_linalg @ inputArray):\n{calculated_goal}")
    print(f"Difference from goalArray (should be close to zero):\n{np.abs(goalArray - calculated_goal).sum():.4f}")
else:
    print("\ninputArray is singular, cannot solve directly using linear algebra.")

这种线性代数方法在 inputArray 非奇异的情况下,通常比迭代优化算法更精确、更快速。它应该作为解决此类特定问题的首选方法。

总结

在Python中使用SciPy进行数值优化时,理解优化函数对输入参数的处理方式至关重要。

  1. 重塑输入: scipy.optimize.fmin 或 scipy.optimize.minimize 会展平初始猜测值。在目标函数内部,务必将展平的数组重塑回其预期的多维形状。
  2. 向量化操作: 充分利用NumPy的向量化能力,用 np.sum, np.abs, np.linalg.norm 等函数替代显式循环,以提高目标函数的计算效率和可读性。
  3. 使用 optimize.minimize: 对于新的优化任务,优先选择 scipy.optimize.minimize。它提供了更现代、更灵活的接口,支持多种优化算法,并返回更详细的优化结果。
  4. 考虑问题本质: 在某些情况下,看似需要优化的任务可能通过更直接的数学方法(如线性代数求解)来解决,这通常会提供更精确和高效的答案。在开始复杂的数值优化之前,分析问题的数学性质是明智之举。

相关专题

更多
python开发工具
python开发工具

php中文网为大家提供各种python开发工具,好的开发工具,可帮助开发者攻克编程学习中的基础障碍,理解每一行源代码在程序执行时在计算机中的过程。php中文网还为大家带来python相关课程以及相关文章等内容,供大家免费下载使用。

715

2023.06.15

python打包成可执行文件
python打包成可执行文件

本专题为大家带来python打包成可执行文件相关的文章,大家可以免费的下载体验。

625

2023.07.20

python能做什么
python能做什么

python能做的有:可用于开发基于控制台的应用程序、多媒体部分开发、用于开发基于Web的应用程序、使用python处理数据、系统编程等等。本专题为大家提供python相关的各种文章、以及下载和课程。

739

2023.07.25

format在python中的用法
format在python中的用法

Python中的format是一种字符串格式化方法,用于将变量或值插入到字符串中的占位符位置。通过format方法,我们可以动态地构建字符串,使其包含不同值。php中文网给大家带来了相关的教程以及文章,欢迎大家前来阅读学习。

617

2023.07.31

python教程
python教程

Python已成为一门网红语言,即使是在非编程开发者当中,也掀起了一股学习的热潮。本专题为大家带来python教程的相关文章,大家可以免费体验学习。

1235

2023.08.03

python环境变量的配置
python环境变量的配置

Python是一种流行的编程语言,被广泛用于软件开发、数据分析和科学计算等领域。在安装Python之后,我们需要配置环境变量,以便在任何位置都能够访问Python的可执行文件。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

547

2023.08.04

python eval
python eval

eval函数是Python中一个非常强大的函数,它可以将字符串作为Python代码进行执行,实现动态编程的效果。然而,由于其潜在的安全风险和性能问题,需要谨慎使用。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

575

2023.08.04

scratch和python区别
scratch和python区别

scratch和python的区别:1、scratch是一种专为初学者设计的图形化编程语言,python是一种文本编程语言;2、scratch使用的是基于积木的编程语法,python采用更加传统的文本编程语法等等。本专题为大家提供scratch和python相关的文章、下载、课程内容,供大家免费下载体验。

697

2023.08.11

桌面文件位置介绍
桌面文件位置介绍

本专题整合了桌面文件相关教程,阅读专题下面的文章了解更多内容。

0

2025.12.30

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 0.6万人学习

Django 教程
Django 教程

共28课时 | 2.6万人学习

SciPy 教程
SciPy 教程

共10课时 | 0.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号