动态规划之字符串编辑距离

动态规划之字符串编辑距离

问题描述 #

给定 2 个字符串 a, b. 编辑距离是将 a 转换为 b 的最少操作次数,操作只允许如下 3 种: 插入一个字符,例如:fj -> fxj 删除一个字符,例如:fxj -> fj 替换一个字符,例如:jyj -> fyj

函数原型:

func LevenshteinDis(str1, str2 string) int {
    ...
}

算法适用场景 #

  • 拼写检查
  • 输入联想
  • 语音识别
  • 论文检查
  • DNA分析

问题分析 #

假定函数edit_dis(stra, strb)表示,stra到strb的编辑距离。算法问题可以分为四种情况:

  1. edit_dis(0, 0) = 0
  2. edit_dis(0, strb) = len(strb)
  3. edit_dis(stra, strb) = len(stra)
  4. edit_dis(stra, strb) = ?

对于4th一般情况,没有办法直接给出求解方式,我们来分析edit_dis(stra+chara, strb+charb)可能的情况:

  1. stra能转成strb,那么只需要判断chara是不是等于charb (cur_cost = 0 if chara == charb else 1)
  2. stra+chara能转成strb, 那么要让stra + chara 转成strb+ charb, 只需要插入charb就行了
  3. 如果stra 可以直接转成strb+charb,那么删除chara就可以转换成功了

综上的分析,可以得到如下DP公式:

                    |-- 0, (i=0, j=0)
                    |-- j, (i=0, j>0)
edit_dis(i, j) =    |-- i, (i>0, j=0)
                    |-- min{edit_dis(i-1, j)+1, edit_dis(i, j-1)+1, edit_dis(i-1, j-1) + cur_cost}
                    # cur_cost = 0 if chara == charb else 1

到这里,完全可以开始动手写代码了。如果还不清楚,可以参考Levenshtein Distance, in Three Flavors,里面有详细步骤和分析

编码及测试 #

使用Golang编码LevenshteinDistance如下:

func LevenshteinDistance(source, dest string) int {
    var cols, rows int = len(source), len(dest)
    if cols == 0 {
        return rows
    }
    if rows == 0 {
        return cols
    }
    var ld *LD = &LD{Rows: rows, Cols: cols}
    ld.constructMatrix() // 初始化二维矩阵
    // PrintMatrix(ld.M)

    // step 5
    for c := 1; c <= cols; c++ {
        for r := 1; r <= rows; r++ {
            var cur_cost int = 1

            if source[c-1] == dest[r-1] {
                cur_cost = 0
            }
            // step 6
            cost := minOfThree(ld.M[r-1][c-1]+cur_cost, ld.M[r-1][c]+1, ld.M[r][c-1]+1)
            // step 7
            ld.setMatrix(cost, r, c)
        }
    }

    PrintMatrix(ld.M)

    return ld.M[rows][cols]
}

做了简单的划分,让算法看起来更清晰,这里是放入了一般情况的处理。

测试代码及测试结果截图:

// ...

func Test_PrintMatrix(t *testing.T) {
    m := Matrix{
        {1, 2, 3, 4},
        {1, 2, 3, 4},
        {1, 2, 3, 4},
        {1, 2, 3, 4},
        {1, 2, 3, 4},
        {1, 2, 3, 4},
        {1, 2, 3, 4},
    }
    PrintMatrix(m)
}

func Test_LD_ConstructMatrix(t *testing.T) {
    ld := &LD{Rows: 3, Cols: 6}
    ld.constructMatrix()

    if len(ld.M[0]) != 7 {
        t.Fatal("ConstructMatrix make a wrong matrix with cols")
    }
    if len(ld.M) != 4 {
        t.Fatal("ConstructMatrix make a wrong matrix with rows")
    }
    // display
    PrintMatrix(ld.M)
}

func Test_LevenshteinDistance(t *testing.T) {
    source := "GUMBO"
    dest := "GAMBOL"

    dis := LevenshteinDistance(source, dest)

    if dis != 2 {
        t.Fatalf("wrong dis is got, %d, actual: %d", dis, 2)
    }
}

func Test_LevenshteinDistance_case1(t *testing.T) {
    source := ""
    dest := "GAMBOL"

    dis := LevenshteinDistance(source, dest)

    if dis != 6 {
        t.Fatalf("wrong dis is got, %d, actual: %d", dis, 2)
    }
}

func Test_LevenshteinDistance_case2(t *testing.T) {
    source := "GUMBO"
    dest := ""

    dis := LevenshteinDistance(source, dest)

    if dis != 5 {
        t.Fatalf("wrong dis is got, %d, actual: %d", dis, 2)
    }
}

func Test_LevenshteinDistance_case3(t *testing.T) {
    source := "GUMBO"
    dest := "GUMBO"

    dis := LevenshteinDistance(source, dest)

    if dis != 0 {
        t.Fatalf("wrong dis is got, %d, actual: %d", dis, 2)
    }
}

func Test_LevenshteinDistance_case4(t *testing.T) {
    source := ""
    dest := ""

    dis := LevenshteinDistance(source, dest)

    if dis != 0 {
        t.Fatalf("wrong dis is got, %d, actual: %d", dis, 2)
    }
}

测试结果

参考资料 #

代码地址 #

github.com/yeqown/alg/dp/levenshtein_dis.go github.com/yeqown/alg/dp/levenshtein_dis_test.go

访问量 访客数