UP | HOME

自动 CPS 变换

Table of Contents

前言

我最一开始听到 CPS 变换这个词是在王垠的博客里 (请求不要喷我),就是那篇他第一次宣传他的40行代码的文章。我当时什么都看不懂,所以没太注意,不过我也正在学程序语言方面的东西,不久我就在 EOPL 和 The Little Schemer 里面又见到了 CPS。我有点不服气,知道了 CPS 不过就是这么个东西,于是我也开始想自己重造王垠40行代码,然后我很惊讶地也花了 刚好一个星期写了出来,而且还基本上跟他的一模一样……

(注意,我不是引战)

毕竟可能每个人有自己的思考方式,我在这里只是分享一下我自己的思路,我写出这个 CPS 程序的经历。当然为了显得我稍微强一点的样子,我把中间许多非常蠢的错误都省略了。

我写出这个程序以后,又去看了那篇经典论文 http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.42.4772 这里有一个更友好一点的版本 http://matt.might.net/articles/cps-conversion/ 发现我的思路和他的完全不同,我反倒觉得我的思路很清楚,他的思路我要绕个弯才能看懂。

这篇文章不涉及什么是 CPS 变换,CPS 变换入门请参考 The Little Schemer,我直接就开始写 cps 函数了,我们的目标就是王垠的那40行代码,我也把我自己的程序的变量名什么的都改成了跟那段代码一样的格式,方便对照,当然也有一些地方不一样,懒得改了。

我采用 Racket 语言,就是因为用起来方便一些。就这样了。

简单的 CPS 变换

其实一个正确的 CPS 变换程序,只要学过一点点写解释器的人都会写,所以我就不细讲了,只是提供一个回忆,你如果要看下去的话最好是这个会自己写。我们先只考虑 λ-calculus 的语法,就只有3个分支,变量,λ,函数调用。先把最简单,没有经过优化的程序写出来。

主要的函数是 cps1,它有两个参数 exprctxexpr 就是一个 s-expression, ctx 是当前的 context,是一个 symbol 或 '(λ...),比如 (cps1 '(f x) '(λ(x) x)) 的值为 '(f x (λ(x) x))。所以 (cps expr) 函数 就是 (cps1 expr 'id)(cps1 expr '(λ(x) x)),我暂且用前者。

先把用到的辅助函数都列出来

(define (cps expr)
  (define (atom? x)
    (not (pair? x)))
  (define n -1)
  (define (fv)
    (set! n (add1 n))
    (string->symbol (string-append "v" (number->string n))))
  (define (cps1 expr ctx)
    ....)
  (cps1 expr 'id))

下面就开始写 cps1

(define (cps1 expr ctx)
  (match expr
    [(? atom? expr) ....]
    [`(λ(,x) ,body) ....]
    [`(,rator ,rand) ....]))

第一项,当 expr 为一个变量(比如 'x ) 或一个原始类型(比如 123)时,我们就直接把它返回,比方说,

(cps1 'x 'k) ==> '(k x)
(cps1 'x '(λ(x) ....))
==> '((λ(x) ....) x)

因此,

[(? atom? expr) `(,ctx ,expr)]

对于一个 λ 表达式,我们同样直接把它返回,只是先递归进函数体内,把函数体进行 CPS 变换。每个 λ 的 continuation 都为 'k,比如,

(cps1 '(λ(x) x) 'id)
==> '(id (λ(x k) (k x)))

因此,

[`(λ(,x) ,body)
 `(,ctx (λ(,x k) ,(cps1 body 'k)))]

处理函数的部分稍微复杂一些,CPS 输出应该先分别计算函数和参数,然后调用,举个例子,在我们最终要完成的代码里,应该大致是这样的,

(cps1 '((f a) (g b)) 'id)
==>
'(f a
    (λ(v0) (g b
              (λ(v1) (v0 v1 id)))))

v0(f a) 的结果,v1(g b) 的结果,然后调用。

我们先要实现的会产生一些多余代码,比如,

(cps1 '(f x) 'id)
==> '((λ(v0)
        ((λ(v1) (v0 v1 id))
         x))
      f)
<=> '(let ([v0 f])
       (let ([v1 x])
         (v0 v1 id)))

可以看出来,我们先分别 CPS 了 fx,然后调用 (vf vx ctx),其中 vfvx 是我们前面定义的 (fv) 产生的,这样保证了每次都不一样。(说实话 (gensym) 也是可以的)

[`(,rator ,rand)
 (define v-rator (fv))
 (define v-rand  (fv))
 (cps1 rator
       `(λ (,v-rator)
          ,(cps1 rand
                 `(λ (,v-rand)
                    (,v-rator ,v-rand ,ctx)))))]

于是就全部搞定了。

于是我就兴高采烈地试了一下这个程序,

(cps '((f a) (g b)))
==>
'((λ (v4)
   ((λ (v5)
      (v4
       v5
       (λ (v0)
         ((λ (v2)
            ((λ (v3)
               (v2
                v3
                (λ (v1) (v0 v1 id))))
             b))
          g))))
    a))
 f)

……

不过算作是个很好的开头吧。

最简 CPS 输出

其实下面才开始真正的任务,上面一节只是因为,市面上的程序都是分好几个函数,我要把它们合在一起。

上面的程序的问题就在于,当 ctx'(λ(v) ...v...),而且 expr 是一个 'x 之类时,输出应该为 '...x... 而不是 '((λ(v) ...v...) x),照 λ-calculus 的术语说就是产生了一个 beta-redex。

我们来观察一下现在我们的 CPS程序的3个分支产生的 ctx

case 1:如果是一个 atom,就产生 `(,ctx ,expr),这时 ctx 在函数的位置。

case 2:如果是 λ 表达式,ctx 也在函数的位置。

case 3:但如果是函数调用,这时 ctx 在参数的位置((vf vx ctx))

很显然,在参数位置时 ctx 是不可能被化简的,因为结果必须是 (vf vx k/id)(vf vx (λ(v?) ???)) 的形式。而在函数位置时是有可能化简的,当它是 λ 函数的时候。

所以 ctx 在函数位置(case 1,2)时应该和 case 3 统一一下。就是说,为了化简,我们把 `(λ(v?) ???)quasiquote 直接去掉,改成一个函数 (λ(v?) `???),调用它就相当于直接把函数体里面的 v? 替换掉了,比如

;; 原来的输出
'((λ(v0) (f v0))
  x)
;; 现在变成
((λ(v) `(f ,v))
 'x)
==> '(f x)

第二,如果 ctx'k/id,就改成 (λ(out) `(k/id ,out))

因为总共就两种情况: ctx 在函数位置和参数位置。我们不妨把 cps1 函数的 ctx 参数改成两个,一个叫 ctx-f 在函数位置,一个叫 ctx-a 在参数位置。

;; ctx-f : symbol -> s-exp
;; ctx-a : s-exp
(define (cps1 expr ctx-f ctx-a)
  (match expr
    ....))

(cps1 expr (λ(out) out) 'id)

(λ(out) out) 就是 id 这个函数。因为原来的 'id 可以看成是 `(λ(v?) v?),所以化简后 就变成了 (λ(out) out)

接下来就是 match 的分支。

[(? atom?) (ctx-f expr)]
[`(λ(,x) ,body)
 (ctx-f `(λ(,x k) ,(cps1 body
                         (λ(out) `(k ,out))
                         'k)))]

这两个应该都是好理解的,都调用了 ctx-f 来化简当前的式子。

但是在 CPS 函数调用时遇到了一些麻烦,因为我们发现总共有4种情况需要讨论,其中有大量重复的代码,但是不管怎么说,先把代码写出来才是正道 (以下代码会需要一点耐心)

[`(,rator ,rand)
 (define v-rator (fv))
 (define v-rand (fv))
 (cps1 rator
       (λ (out-rator)
         (cps1 rand
               (λ (out-rand)
                 `(,out-rator ,out-rand ,ctx-a))
               `(λ (,v-rand)
                  (,out-rator ,v-rand ,ctx-a))))
       `(λ (,v-rator)
          ,(cps1 rand
                 (λ (out-rand)
                   `(,v-rator ,out-rand ,ctx-a))
                 `(λ (,v-rand)
                    (,v-rator ,v-rand ,ctx-a)))))]

令人惊讶的是,这个 cps 函数就这么完成了! 跟前面的结果对照一下就会看出明显区别了。它不但可以处理 beta-redex,还能正确处理尾递归。

这段代码应该也不难理解,只是分别讨论了函数和参数分别处于函数位置和参数位置的情况。大致思路就是,首先,cps1 要根据 ratorrand 是否为一个 atom 来决定如何输出,其次,我们不愿意在递归进去之前就判断一次,递归进去之后又要 match expr (开头提到的那篇论文的方法就有这个问题)。所以我们把现在的状态分成了两个参数,也一起递归进去。

它有唯一一个但很好修复的缺陷,就是 v-ratorv-rand 定义地太早了,所以有时候会出现 vn 不连续的情况,如果不嫌麻烦的话可以在每次第一次出现 v-? 的地方再 (let ([v-? (fv)]) ....),当然这个代码看起来就…… 另外,做出了下面一道习题后也会很好修复这个缺陷。

论文里的方法大概就是这样,只是写成了好几个函数

[`(,rator ,rand)
 (if (atom? rator)
     (if (atom? rand)
         ....
         ....)
     (if (atom? rand)
         ....
         ....))]

其实这样写也是完全可以的,这时 cps1 只需要传一个参数 ctx-f 就可以了,在判断出不是 atom 以后用 'vn 调用 ctx-f,把它转换成 ctx-a,这其实更接近王垠的版本。

甚至还有一个写法,就是利用多返回值,再返回一个布尔值表示当前的选择,这个方法看起来会有些麻烦。就不提了。在本质上,这种多返回值的方法跟使用 continuation 的方法是完全等价的,你可以类比一下 JavaScript 里面的 Promise 和传统的回调函数是等价的。(我不喜欢 JavaScript,但是这个例子能明白的人应该比较多吧)

现在举两个例子,看一下这个程序是怎么做的,

(cps '(f x))
==> (cps1 'f ctx-f ctx-a)
==> (ctx-f 'f)
==> (cps1 'x .... ....)where[out-rator='f]
==> `(,out-rator ,out-rand id)where[out-rator='f out-rand='x]
==> '(f x id)

但如果没有 beta-redex 可以化简,比如,

(cps '((f a) b))
==> (cps1 '(f a) ctx-f ctx-a)
==> `(f a ,ctx-a)where[ctx-a='(λ(v0) (v0 b id))]
==> '(f a (λ(v0) (v0 b id)))

现在可以来看一下这段代码对我们有什么启发。

所谓的 continuation-passing style 多用一个参数 k 来告诉 我们要调用的函数当前的状态是什么,就是这个函数运行完了以后 应该干什么。但这里的 cps 函数也有一个参数 ctx,它也表示一个状态,它表示的是现在的状态,让更深层递归的函数能得知一些外部信息。

很多时候我们发现就传一个死的数据(比如第一个版本里的 ctx ) 是不够的,不但递归进去的函数需要这个数据,而这个数据也要随着当前的情况而变化。在简单的情况下我们可以传好几个参数,或者一个对象进去,里面的函数 选择性地使用这些数据。但是在支持高阶函数的语言里很多时候方便很多,因为我们可以传一个函数进去。

更通常的情况下,上面的 cps1 可以只有一个参数 ctx,其中 ctx 是这样的,

ctx=
(λ(position)
  (cond
    [(eq? position 'f) ....]
    [(eq? position 'a) ....]))

这两个分支分别为原来的 ctx-fctx-a

这种模式更广泛的应用之处在于 ctx 的参数不是一个用来选择的符号,而是一个连续数值或对象的时候。我一下子想不出实际的例子,但我感觉这种模式有不错的应用前景,有了例子我会补充。

习题:请扩展这个程序以支持多参数的 λ 和函数调用。

令我惊讶的是,支持多参数就不用分4类讨论了! 因为只要分两类讨论,依次遍历整个列表就可以,不用区分函数和参数,所以代码反倒简单多了。推荐做一下这个习题。

简化 cps1 函数

这一节,我们把 ctx-fctx-a 合并成一个 ctx

观察所有产生的 ctx-f/a 参数,总结一下总共有这些:

  1. λ(out) `(k ,out)
  2. 'k
  3. λ(out) ....`(....,out ....)
  4. `(λ(,vn) ....(....,vn ....))

如果要只传一个参数的话,我们会发现,由2可以推出1,因为我们只要给它包一个 λ 就可以了。由3可以推出4,如果3是 ctx,4就是 `(λ(,vn) ,(ctx vn))

问题就在于,1、3是一个形式的,2、4是一个形式的,我们要选择的就是只传1、3还是只传2、4.

我们发现,1、3是两个固定的值,而2、4里面是有一堆省略号的,也就是说,如果采用一点类似作弊的策略,从1也可以推出2,只要判断 ctx 是否等于 (λ(out) `(k ,out))。但是无论如何也不可能从任意的4推出3(当然你如果使用 eval 的话,我就没话说了,按理来说是可以的,你可以自己尝试一下,成功了记得偷偷告诉我一声)。

于是,我们决定采用1、3型的 ctx

先写一个转换的函数,把一个1、3型的 ctx 转成2、4型的。就是把 ctx-f 转成 ctx-a

(define (ctx-f->a ctx)
  (if (ctx1? ctx)
      'k
      (let ([v (fv)])
        `(λ(,v) ,(ctx v)))))

但是我们怎样知道一个 ctx 是不是 ctx1? 呢,就是不直接写出 (λ(out) `(k ,out)),而是定义

(define ctx1 (λ(out) `(k ,out)))

(看到了吗,这就是王垠 CPS 代码里的 ctx0 )

于是,

(define (ctx1? ctx) (eq? ctx ctx1))

这样完成了之后,原来那段代码所有的 ctx-a 都不用手写了,只要改为 (ctx-f->a ctx-f) 就可以了。因为能这样直接转化,所以也没有必要传两个 ctx 参数了,我们在需要用到 ctx-a 时现转化就可以,于是,我们最终得到了这样的代码。

(define (cps1 expr ctx)
  (match expr
    [(? atom? expr) (ctx expr)]
    [`(λ(,x) ,body)
     (ctx `(λ(,x k) ,(cps1 body ctx1)))]
    [`(,rator ,rand)
     (cps1 rator
           (λ (out-rator)
             (cps1 rand
                   (λ (out-rand)
                     `(,out-rator
                       ,out-rand
                       ,(ctx-f->a ctx))))))]))

是的,这就是完整的 λ-calculus 的最简的 CPS 变换,如果对照一下王垠的 CPS 变换的最后几行,你会发现我的这个版本甚至更清晰一些,因为我用 ctx-f->a 这个函数避免了 `(,out-rator ,out-rand ....) 这样重复的代码,并把判断也放进了这个辅助函数中。

你现在可以自己随意试验这个程序了。

下面我们对它进行一些扩展,先增加多参数的λ和函数调用,然后是原生的几个函数(比如 +-zero? 等),最后添加if语句。

多参数和原生函数

都已经到这一步了,支持多参数其实很简单。

第一种情况不用变。

[(? atom? expr) (ctx expr)]

第二种情况只要把 λ 的参数改成一个列表就可以了。

[`(λ ,args ,body)
 (ctx `(λ(,@args k) ,(cps1 body ctx1)))]

第三种情况甚至只要遍历列表就可以

[_ ; else : expr = `(,rator . ,rands)
 (let recur ([exprs expr] [acc '()])
   (if (null? exprs)
       `(,@acc ,(ctx-f->a ctx))
       (cps1 (car exprs)
             (λ(v) (recur (cdr exprs) `(,@acc ,v))))))]

说实话,最后这个递归…我也不知道我自己是怎么写出来的,希望有人能帮我改成普通的循环之类的,简化一下代码。思路就是这样,首先最后返回值肯定是
(cps1 (car exprs) (λ(v) ....))
然后省略号要填的是,递归遍历 (cdr exprs),所以结构必须是这样的,

(let recur ([exprs expr])
  (cps1 (car exprs) ; when exprs is not null
        (λ (v)
          (recur (cdr exprs))
          (process-v))))

然后当 exprs 最后变成 null 的时候,我们需要把之前所有的 v : v1 v2 v3 .... 收集起来,返回 `(,v1 ,v2 ,v3 .... ,vn ,(ctx-f->a ctx)),因此再多一个变量 acc,用来收集这些 v。这个程序就完成了。你可以自己试验一下确保它正确。

接下来我们要支持一些原生的函数,比如 + - * / zero?,这些函数不需要经过 CPS 变换,比如,

(cps '(+ x y))
;; instead of (+ x y id)
==> '(+ x y)

(cps '(+ (f x) y))
==> '(f x (λ(v0) (+ v0 y)))

(cps '(+ (* x y) z))
==> '(+ (* x y) z)

;; when used as higher order procedure
(cps '(((λ(m) +) n) ; returns +
       x y))
==> '((λ(m k)
        (k +))
      n
      (λ(v0) (v0 x y id)))

首先,

(define (trivial? x)
  (memq x '(zero? add1 sub1 + - * /)))

然后,处理函数的不分大多数不变,先在最后进行一个判断。

[_
 (let recur ([exprs expr] [acc '()])
  (if (null? exprs)
      (if (trivial? (car acc))
          ....
          `(,@acc ,(ctx-f->a ctx)))
      (cps1 (car exprs)
            (λ(v) (recur (cdr exprs) `(,@acc ,v))))))]

如果 (car acc) 是原生函数,就把整体当成一个值返回,所以非常简单,

(if (trivial? (car acc))
    (ctx acc)
    `(,@acc ,(ctx-f->a ctx)))

这样就完成了。现在只差最后的难点,就是 if 语句了。

if 语句

首先,if 语句的不同之处在于,我们需要对它的两个分支做两次 CPS 变换,比如,

(cps '(λ(x) (if a b (f c))))
==> '(λ(x k) (if a (k b) (f c k)))

是的,这个没什么难的,既然都到这里了,应该随手就可以写出来了,

[`(if ,test ,conseq ,alt)
 (cps1 test
       (λ(t) `(if ,t
                  ,(cps1 conseq ctx)
                  ,(cps1 alt ctx))))]

把这段代码插到 CPS 函数调用的前面。我一开始以为就这么完成了,结果发现,这里的 ctx 是会被翻倍的。

如果 ctxk 的话,没有关系,但是如果是一个比较复杂的式子,就是说,if 语句嵌套在了不是尾递归的地方,比如函数的参数,或 if 的判断语句,就出现了一些问题,

(cps '(λ(x) (f (if a b c))))
==> '(λ (x k) (if a (f b k) (f c k)))

看起来没什么毛病,但那是因为我们的 ctx 太简单了,制造一点更复杂的,

(cps '(λ(x) (f (g (h (if a b c))))))
==> '(λ (x k)
       (if a
           (h
            b
            (λ (v0) (g v0 (λ (v1) (f v1 k)))))
           (h
            c
            (λ (v2) (g v2 (λ (v3) (f v3 k)))))))


(cps '(λ(x) (if (if a b c) d e)))
==> '(λ (x k)
       (if a
           (if b (k d) (k e))
           (if c (k d) (k e))))

一眼看过去,长得一模一样的代码很多,第一个例子里面,(h b/c ....) 就只有 bc 不同,后面完全一样,第二个例子也是这样。

解决方法也很直接,在发现当前的 ctx 不是最简的时候,我们用一个 let 包住当前的 ctx,最终结果变成这样,

(cps '(λ(x) (f (if a b c))))
==> '(λ(x k)
       (let ([k (λ(v0) (f v0 k))])
         (if a (k b) (k c))))

(cps '(λ(x) (if (if a b c) d e)))
==> '(λ(x k)
       (let ([k (λ(v0) (if v0 (k d) (k e)))])
         (if a (k b) (k c))))

 (这看起来怎么像是 ANF 和 CPS 的混合版)

所以,也是只要加一个判断即可,

[`(if ,test ,conseq ,alt)
 (define (if-body ctx)
   (cps1 test
         (λ(t) `(if ,t
                    ,(cps1 conseq ctx)
                    ,(cps1 alt ctx)))))
 (if (ctx1? ctx)
     (if-body ctx)
     `(let ([k ,(ctx-f->a ctx)])
        ,(if-body ctx1)))]

好了,现在一切都已经完成了。这个教程可以基本圆满地结束了。

那么最后,还有一点小问题可以修复,就是在这种 ctxid 的时候,

(cps '(if a b c))
==> '(let ([k (λ(v0) v0)])
       (if a (k b) (k c)))

我们不太希望它只是 (let ([k id])),而希望直接输出 (if a b c),因此,我们把 id 也归为一类 ctx (这应该是第五类了?),先插入定义

(define id (λ(x) x))
(define (id? x) (eq? x id))

然后把处理 if 的代码加上一句,

(if (or (ctx1? ctx)
        (id? ctx))
    (if-body ctx)
    .... as before)

最后,

(define (cps expr)
  ....
  (cps1 expr id))

就可以了。

再最后,如果有人看着 (cps '(f x)) 变成 '(f x (λ(v0) v0)) 感觉不爽的话,可以这么改一下,

(define (ctx-f->a ctx)
  (cond
    [(ctx1? ctx) 'k]
    [(id? ctx) 'id]
    [else
     (define v (fv))
     `(λ(,v) ,(ctx v))]))
;; 话说这里用 case 语句会更舒服一点的...

这样,(cps '(f x)) 就是 '(f x id) 啦。

总结

这么多代码看下来,其实你会发现,就只有几个关键点,只要想到了,其实也没有多难。我自己想这个程序的时候,手头上没有电脑,我是写在纸上的(好痛苦啊),但是放到电脑上测试,一次性就全是对的,毕竟不是很大的工程,也没有各种复杂的角角落落需要考虑,思路还是很简单的。

这个程序还有升级空间,就是 beginset! 语句,提示一下,东西越来越复杂的时候,可能不得不回归到第3节中的方式,把各种 ctx 拆开,否则处理 set! 的时候会产生一堆嵌套的 begin 语句。另外,如果是 Common Lisp 里的那种有返回值的赋值语句,处理起来会简单一些,因为可以简单地看作一个表达式。

最后就随便说说,其实这段代码也没有特别的高级,只是自己写出来了,那就开心一下就好。代码里倒是有几个挺特别的想法值得学习。

文笔不好请见谅,有任何错误或写的不好的地方欢迎指出。

最后的最后,我把本篇文章完整的代码贴上,

(define (cps expr)
  (define (atom? x)
    (not (pair? x)))
  (define n -1)
  (define (fv)
    (set! n (add1 n))
    (string->symbol (string-append "v" (number->string n))))
  (define ctx1 (λ(out) `(k ,out)))
  (define (ctx1? ctx) (eq? ctx ctx1))
  (define (ctx-f->a ctx)
    (if (ctx1? ctx)
        'k
        (let ([v (fv)])
          `(λ(,v) ,(ctx v)))))
  (define (trivial? x)
    (memq x '(zero? add1 sub1 + - * /)))
  (define id (λ(x) x))
  (define (id? x) (eq? x id))
  (define (cps1 expr ctx)
    (match expr
      [(? atom?) (ctx expr)]
      [`(if ,test ,conseq ,alt)
       (define (if-body ctx)
         (cps1 test
               (λ(t) `(if ,t
                          ,(cps1 conseq ctx)
                          ,(cps1 alt ctx)))))
       (if (or (ctx1? ctx)
               (id? ctx))
           (if-body ctx)
           `(let ([k ,(ctx-f->a ctx)])
              ,(if-body ctx1)))]
      [`(λ ,args ,body)
       (ctx `(λ(,@args k) ,(cps1 body ctx1)))]
      [_
       (let recur ([exprs expr] [acc '()])
         (if (null? exprs)
             (if (trivial? (car acc))
                 (ctx acc)
                 `(,@acc ,(ctx-f->a ctx)))
             (cps1 (car exprs)
                   (λ(v) (recur (cdr exprs) `(,@acc ,v))))))]))
  (cps1 expr id))