算法题-对KMP算法的理解

  1. KMP算法
  2. 找出字符串中第一个匹配项的下标

KMP算法

KMP算法可以求出模式串$p$在给定字符串$s$出现的位置 (注意, 下面的讨论中, 下标都从1开始).

KMP算法本质上是一个状态机的算法, 我现在有原串$s$的很多字符$c_1, c_2, …, c_n$, 要喂到这个状态机中, 这个状态机有这几个状态:

  • 状态0: 起始状态, 还没有成功匹配一个字符.
  • 状态1: 我喂进去了一个字符$c$, 并且成功匹配了.
  • 状态2: 我喂进去了两个字符$c_1, c_2$, 并且都成功匹配了.
  • 直到状态$m$ ($m$是模式串的长度), 表示我喂进去了$m$个字符, 都成功匹配了, 那就说明从原串中找到模式串了, 成功.

因此, 一共有$m + 1$个状态.


假设我现在在状态$i$, 有一个字符$c_{i + 1}$没匹配上, 那么应该转移到哪个状态?

假设应该转移到状态$j$, 那么$p[1…j] = p[i - j + 1, i]$, 并且$j$应该足够大, 那么转移到的状态就是以$s[i]$结尾的最长公共前后缀的长度, 这个长度就是$j$, (下标为1的前提下).

next数组: next[i]表示我如果在状态i, 遇到了没匹配的字符, 应该转移到什么状态, 其实就是模式串$p$以字符p[i]结尾的最长公共前后缀的长度:

  • ne[0] = 0
  • ne[1] = 0
  • 对于ne[i], 首先用字符p[i]喂入状态机, 假设转移到了状态j, 那么j就是以字符p[i - 1]结尾的最长公共前后缀的长度. 此时, 如果p[j + 1] == p[i], 那么以p[i]结尾的最长公共前后缀的长度就是j + 1.
#include <iostream>

using namespace std;

const int N = 1000010, M = 100010;

int n, m;
int ne[M];
char s[N], p[M];

int main() {

    cin >> m >> p + 1 >> n >> s + 1;

    for (int i = 2, j = 0; i <= m; i ++) {
      /* j记录以p[i - 1]结尾的最长公共前后缀的长度 */
        while (j && p[j + 1] != p[i]) j = ne[j];
        if (p[j + 1] == p[i]) j ++;
        ne[i] = j;
    }

    for (int i = 1, j = 0; i <= n; i ++) {
        while (j && p[j + 1] != s[i]) j = ne[j];
        if (p[j + 1] == s[i]) j ++;
        if (j == m) {
              /* i - (m - 1) - 1 , 最后减1是因为答案要求的下标需要从0, 但是运算的时候是1开始*/
            printf("%d ", i - m);
            j = ne[j];
        }
    }

    return 0;
}

KMP算法时间复杂度是$O(n + m)$, 空间复杂度是$O(m)$.

暴力算法时间复杂度是$O(nm)$, 空间复杂度是$O(1)$.

找出字符串中第一个匹配项的下标

https://leetcode.cn/problems/find-the-index-of-the-first-occurrence-in-a-string/

class Solution {
public:
    int strStr(string s, string p) {
        if (p.empty()) return 0;

        // 注意先获取原长度, 然后再把下标变成1
        int n = s.size(), m = p.size();
        s = ' ' + s, p = ' ' + p;

        // 注意next数组的长度是p长度+1
          // 注意, 这里不能用int next[m + 1], 一定不要在局部用变长数组
        vector<int> next(m + 1);

        for (int i = 2, j = 0; i <= m; i ++) {
            while (j && p[j + 1] != p[i]) j = next[j];
            if (p[j + 1] == p[i]) j ++;
            next[i] = j;
        }

        for (int i = 1, j = 0; i <= n; i ++) {
            while (j && p[j + 1] != s[i]) j = next[j];
            if (p[j + 1] == s[i]) j ++;
            if (j == m) return i - m;
        }
        return -1;
    }
};